import os
import time
import json
import numpy as np
import pandas as pd
from datetime import datetime
import random
import logging
import asyncio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from dotenv import load_dotenv
import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp
from sklearn.preprocessing import MinMaxScaler
import copy
import argparse
import traceback
import io
import matplotlib.dates as mdates
from matplotlib.figure import Figure
from PIL import Image
import matplotlib.pyplot as mpf
import matplotlib.gridspec as gridspec
import datetime  # NOTE: rebinds the name `datetime` to the module, shadowing the class imported above
from dataprovider_realtime import BinanceWebSocket, BinanceHistoricalData
from datetime import datetime as dt

# Dash-related imports
import dash
from dash import html, dcc, callback_context
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from threading import Thread
import socket

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)
logger = logging.getLogger("trading_bot")

# Keep the WebSocket logger at INFO so its DEBUG chatter stays out of the log
websocket_logger = logging.getLogger('websocket')
websocket_logger.setLevel(logging.INFO)

class WebSocketFilter(logging.Filter):
    """Filter out DEBUG messages from WebSocket-related modules."""
    def filter(self, record):
        if record.levelno == logging.DEBUG and ('websocket' in record.name
                                                or 'protocol' in record.name
                                                or 'realtime' in record.name):
            return False
        return True

logger.addFilter(WebSocketFilter())

# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')

# Constants
INITIAL_BALANCE = 100        # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5      # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5    # Take profit at 1.5%
MEMORY_SIZE = 100000
BATCH_SIZE = 64
GAMMA = 0.99                 # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 64              # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10           # Update target network every 10 episodes

# Experience replay tuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])

def find_local_extrema(prices, window=5):
    """Find local minima (bottoms) and maxima (tops) in price data."""
    bottoms = []
    tops = []
    if len(prices) < window * 2 + 1:
        return bottoms, tops

    for i in range(window, len(prices) - window):
        # Check if this is a local minimum (bottom)
        if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] <= prices[i+j] for j in range(1, window+1)):
            bottoms.append(i)
        # Check if this is a local maximum (top)
        if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] >= prices[i+j] for j in range(1, window+1)):
            tops.append(i)

    return bottoms, tops
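
# Illustrative sketch (defined here for reference, never called by the bot):
# how find_local_extrema flags turning points on a small synthetic series.
# The price values below are made up purely for demonstration.
def _demo_find_local_extrema():
    prices = [10, 9, 8, 7, 8, 9, 10, 11, 10, 9, 8, 9, 10]
    bottoms, tops = find_local_extrema(prices, window=3)
    # prices[3] == 7 is the lowest value within +/-3 candles -> bottom
    # prices[7] == 11 is the highest value within +/-3 candles -> top
    assert bottoms == [3] and tops == [7]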
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super(DQN, self).__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers

        # Initial feature extraction
        self.fc1 = nn.Linear(state_size, hidden_size)
        # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
        self.ln1 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(0.2)

        # LSTM layer for sequential data
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)

        # Attention mechanism
        self.attention = nn.MultiheadAttention(hidden_size, attention_heads)

        # Output layers with increased capacity
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)  # LayerNorm instead of BatchNorm
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(hidden_size, hidden_size // 2)

        # Dueling DQN architecture
        self.value_stream = nn.Linear(hidden_size // 2, 1)
        self.advantage_stream = nn.Linear(hidden_size // 2, action_size)

        # Transformer encoder for more complex pattern recognition
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, x):
        batch_size = x.size(0) if x.dim() > 1 else 1

        # Ensure input has correct shape
        if x.dim() == 1:
            x = x.unsqueeze(0)  # Add batch dimension

        # Check if state size matches expected input size
        if x.size(1) != self.state_size:
            # Handle mismatched input by either truncating or padding
            if x.size(1) > self.state_size:
                x = x[:, :self.state_size]  # Truncate
            else:
                # Pad with zeros
                padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
                x = torch.cat([x, padding], dim=1)

        # Initial feature extraction
        x = self.fc1(x)
        x = F.relu(self.ln1(x))  # LayerNorm works with any batch size
        x = self.dropout1(x)

        # Reshape for LSTM
        x_lstm = x.unsqueeze(1) if x.dim() == 2 else x

        # Process through LSTM
        lstm_out, _ = self.lstm(x_lstm)
        lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]

        # Process through transformer for more complex patterns
        transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
        transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
        transformer_out = transformer_out.transpose(0, 1).mean(dim=1)

        # Combine LSTM and transformer outputs
        x = lstm_out + transformer_out

        # Final layers
        x = self.fc2(x)
        x = F.relu(self.ln2(x))  # LayerNorm works with any batch size
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))

        # Dueling architecture
        value = self.value_stream(x)
        advantages = self.advantage_stream(x)
        qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))

        return qvals
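
# Shape sanity check for the network above (a sketch, not part of the bot's
# training loop): a batch of two STATE_SIZE-dim states should yield one
# Q-value per action. eval() disables dropout so the check is deterministic.
def _demo_dqn_shapes():
    model = DQN(state_size=STATE_SIZE, action_size=4)
    model.eval()
    with torch.no_grad():
        qvals = model(torch.randn(2, STATE_SIZE))
    assert qvals.shape == (2, 4)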
class PricePredictionModel(nn.Module):
    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
        super(PricePredictionModel, self).__init__()
        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False

    def forward(self, x):
        # x shape: [batch_size, seq_len, 1]
        lstm_out, _ = self.lstm(x)
        # Use the last time step output
        predictions = self.fc(lstm_out[:, -1, :])
        return predictions

    def preprocess(self, data):
        # Reshape data for the scaler
        data_reshaped = np.array(data).reshape(-1, 1)
        # Fit scaler if not already fitted
        if not self.is_fitted:
            self.scaler.fit(data_reshaped)
            self.is_fitted = True
        # Transform data
        scaled_data = self.scaler.transform(data_reshaped)
        return scaled_data

    def postprocess(self, scaled_predictions):
        # Inverse transform to get actual price values
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        if len(price_history) < 30:  # Need enough history
            return np.zeros(num_candles)

        # Preprocess data
        scaled_data = self.preprocess(price_history)

        # Create sequence
        sequence = scaled_data[-30:].reshape(1, 30, 1)
        sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)

        # Get predictions
        with torch.no_grad():
            scaled_predictions = self(sequence_tensor).cpu().numpy()[0]

        # Postprocess predictions
        predictions = self.postprocess(scaled_predictions)
        return predictions

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        if len(price_history) < 35:  # Need enough history for training
            return 0.0

        # Preprocess data
        scaled_data = self.preprocess(price_history)

        # Create sequences and targets
        sequences = []
        targets = []
        for i in range(len(scaled_data) - 35):
            # Sequence: 30 time steps
            seq = scaled_data[i:i+30]
            # Target: next 5 time steps
            target = scaled_data[i+30:i+35].flatten()
            sequences.append(seq)
            targets.append(target)

        if not sequences:  # If no sequences were created
            return 0.0

        # Convert to tensors
        sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device)
        targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device)

        # Training loop
        total_loss = 0
        for _ in range(epochs):
            # Forward pass
            predictions = self(sequences_tensor)
            # Calculate loss
            loss = F.mse_loss(predictions, targets_tensor)
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        return total_loss / epochs
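
# Usage sketch for the predictor (not wired into the bot here; the synthetic
# sine-wave prices are made up for illustration):
def _demo_price_predictor():
    model = PricePredictionModel()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    prices = [100 + 5 * np.sin(i / 10) for i in range(200)]
    model.train_on_new_data(prices, optimizer, epochs=2)
    next_five = model.predict_next_candles(prices, num_candles=5)
    assert next_five.shape == (5,)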
# Neural-network integration imports
import sys
# Add the NN directory to the Python path
sys.path.append(os.path.abspath("NN"))
from NN.main import load_model
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
from NN.realtime_data_interface import RealtimeDataInterface

def main():
    """Main function for the trading bot."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
    parser.add_argument('--symbols', nargs='+', default=["ETH/USDT"],
                        help='Trading symbols to monitor')
    parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
                        help='Timeframes to monitor')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for model input')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size of the model (3 for BUY/HOLD/SELL)')
    parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
                        help='Type of neural network model')
    parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
                        help='Trading mode')
    parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
                        help='Exchange to use for trading')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox exchange environment')
    parser.add_argument('--position-size', type=float, default=0.1,
                        help='Position size as a fraction of total balance (0.0-1.0)')
    parser.add_argument('--max-trades-per-day', type=int, default=5,
                        help='Maximum number of trades per day')
    parser.add_argument('--trade-cooldown', type=int, default=60,
                        help='Trade cooldown period in minutes')
    parser.add_argument('--config-file', type=str, default=None,
                        help='Path to configuration file')
    args = parser.parse_args()

    # Load configuration from file if provided
    if args.config_file and os.path.exists(args.config_file):
        with open(args.config_file, 'r') as f:
            config = json.load(f)
        # Override config with command-line args
        for key, value in vars(args).items():
            if key != 'config_file' and value is not None:
                config[key] = value
    else:
        # Use command-line args as config
        config = vars(args)

    # Initialize real-time charts and data interfaces
    try:
        from dataprovider_realtime import RealTimeChart

        # Create a real-time chart for each symbol
        charts = {}
        for symbol in config['symbols']:
            charts[symbol] = RealTimeChart(symbol=symbol)
        main_chart = charts[config['symbols'][0]]

        # Create a data interface for retrieving market data
        data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart)

        # Load trained model
        model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
        model = load_model(
            model_type=model_type,
            input_shape=(config['window_size'], len(config['symbols']), 5),  # 5 features (OHLCV)
            output_size=config['output_size']
        )

        # Configure trading agent
        exchange_config = {
            "exchange": config['exchange'],
            "api_key": config['api_key'],
            "api_secret": config['api_secret'],
            "test_mode": config['test_mode'],
            "trade_symbols": config['symbols'],
            "position_size": config['position_size'],
            "max_trades_per_day": config['max_trades_per_day'],
            "trade_cooldown_minutes": config['trade_cooldown']
        }

        # Initialize neural network orchestrator
        orchestrator = NeuralNetworkOrchestrator(
            model=model,
            data_interface=data_interface,
            chart=main_chart,
            symbols=config['symbols'],
            timeframes=config['timeframes'],
            window_size=config['window_size'],
            num_features=5,  # OHLCV
            output_size=config['output_size'],
            exchange_config=exchange_config
        )

        # Start data collection
        logger.info("Starting data collection threads...")
        for symbol in config['symbols']:
            charts[symbol].start()

        # Start neural network inference
        if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
            logger.info("Starting neural network inference...")
            orchestrator.start_inference()
        else:
            logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")

        # Start web servers for chart display
        logger.info("Starting web servers for chart display...")
        main_chart.start_server()

        logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")

        # Keep the main thread alive
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt received. Shutting down...")
Shutting down...") # Stop all threads for symbol in config['symbols']: charts[symbol].stop() orchestrator.stop_inference() logger.info("Trading bot stopped.") except Exception as e: logger.error(f"Error in main function: {str(e)}", exc_info=True) sys.exit(1) if __name__ == "__main__": main() def get_state(self): """Create state representation for the agent with enhanced features""" # Ensure we have enough data if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0: # Return zeros if not enough data return np.zeros(STATE_SIZE) # Create a normalized state vector with recent price action and indicators state_components = [] # Safely get the latest price try: latest_price = self.features['price'][-1] except IndexError: # If we can't get the latest price, return zeros return np.zeros(STATE_SIZE) # Safely get price features try: # Price features (normalize recent prices by the latest price) price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0 state_components.append(price_features) except (IndexError, ZeroDivisionError): # If we can't get price features, use zeros state_components.append(np.zeros(10)) # Safely get volume features try: # Volume features (normalize by max volume) max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1 vol_features = np.array(self.features['volume'][-5:]) / max_vol state_components.append(vol_features) except (IndexError, ZeroDivisionError): # If we can't get volume features, use zeros state_components.append(np.zeros(5)) # Technical indicators rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1 state_components.append(rsi) # MACD (normalize) macd_vals = np.array(self.features['macd'][-3:]) macd_signal = np.array(self.features['macd_signal'][-3:]) macd_hist = np.array(self.features['macd_hist'][-3:]) macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5) macd_norm = macd_vals / macd_scale macd_signal_norm = macd_signal / macd_scale macd_hist_norm = macd_hist / macd_scale state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm]) # Bollinger position (where is price relative to bands) bb_upper = np.array(self.features['bollinger_upper'][-3:]) bb_lower = np.array(self.features['bollinger_lower'][-3:]) bb_mid = np.array(self.features['bollinger_mid'][-3:]) price = np.array(self.features['price'][-3:]) # Calculate position of price within Bollinger Bands (0 to 1) bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] state_components.append(np.array(bb_pos)) # Stochastic oscillator state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) # Add predicted prices (if available) if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Normalize predictions relative to current price pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0 state_components.append(pred_norm) else: # Add zeros if no predictions state_components.append(np.zeros(3)) # Add extrema signals (if available) if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: # Get recent signals idx = len(self.optimal_signals) - 5 if idx < 0: idx = 0 recent_signals = self.optimal_signals[idx:idx+5] # Pad if needed if len(recent_signals) < 5: recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') state_components.append(recent_signals) else: # Add zeros if no signals 
def get_expanded_state_size(self):
    """Calculate the size of the expanded state representation"""
    # Create a dummy state to get its size
    state = self.get_state()
    return len(state)

async def expand_model_with_new_features(agent, env):
    """Expand the model to handle new features without retraining from scratch"""
    # Get the new state size
    new_state_size = env.get_expanded_state_size()

    # Only expand if the new state size is larger
    if new_state_size > agent.state_size:
        logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")

        # Expand the model
        success = agent.expand_model(
            new_state_size=new_state_size,
            new_hidden_size=512,     # Increase hidden size for more capacity
            new_lstm_layers=3,       # More layers for deeper patterns
            new_attention_heads=8    # More attention heads for complex relationships
        )

        if success:
            logger.info(f"Model successfully expanded to handle {new_state_size} features")
            return True
        else:
            logger.error("Failed to expand model")
            return False
    else:
        logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
        return True
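
# agent.expand_model is defined elsewhere; the sketch below shows one way such
# an expansion can preserve learned weights for a single nn.Linear layer
# (zero-padding the new input columns). This is an illustrative assumption
# about the approach, not the agent's actual implementation.
def _sketch_expand_linear(old_layer: nn.Linear, new_in_features: int) -> nn.Linear:
    new_layer = nn.Linear(new_in_features, old_layer.out_features)
    with torch.no_grad():
        new_layer.weight.zero_()
        # Copy existing weights into the leading columns; new features start at zero
        new_layer.weight[:, :old_layer.in_features] = old_layer.weight
        new_layer.bias.copy_(old_layer.bias)
    return new_layer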
def calculate_reward(self, action):
    """Calculate reward for the given action, with aggressive rewards for
    profitable trades and volume/price-action signals."""
    reward = 0

    # Validate current price
    if self.current_price <= 0 or self.current_price > 1000000:  # Reasonable price range
        logger.error(f"Invalid current price: {self.current_price}")
        return -10.0  # Strong penalty for invalid price

    # Validate position size
    if self.position_size <= 0 or self.position_size > 1000000:  # Reasonable position size range
        logger.error(f"Invalid position size: {self.position_size}")
        return -10.0  # Strong penalty for invalid position size

    if action == 1:  # BUY/LONG
        if self.position == 'flat':
            # Opening a long position
            self.position = 'long'
            self.entry_price = self.current_price
            self.position_size = self.calculate_position_size()
            # Use the adjusted risk parameters
            self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)

            # Check if this is an optimal buy point
            if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                reward += 2.0  # Bonus for buying at a bottom

            # Check for volume spike
            if len(self.features['volume']) > 5:
                avg_volume = np.mean(self.features['volume'][-5:-1])
                current_volume = self.features['volume'][-1]
                if current_volume > avg_volume * 1.5:
                    reward += 2.0  # Bonus for entering during high volume

            # Check for price action signals
            if self.features['rsi'][-1] < 30:  # Oversold condition
                reward += 1.5  # Bonus for buying at oversold levels

            # Check if we're buying in a clear uptrend (good)
            if self.is_uptrend():
                reward += 1.0  # Bonus for buying in uptrend
            elif self.is_downtrend():
                reward -= 0.25  # Reduced penalty for buying in downtrend
            else:
                reward += 0.2  # Small reward for opening a position

            logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif self.position == 'short':
            # Close short and open long
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
            pnl_dollar = pnl_percent / 100 * self.position_size

            # Validate PnL values
            if abs(pnl_percent) > 100:  # Max 100% loss/gain
                logger.error(f"Invalid PnL percentage: {pnl_percent}")
                pnl_percent = max(min(pnl_percent, 100), -100)
                pnl_dollar = pnl_percent / 100 * self.position_size

            # Apply fees
            pnl_dollar -= self.calculate_fees(self.position_size)

            # Clamp PnL to a sane range before updating the balance
            if abs(pnl_dollar) > self.balance * 2:  # Max 200% of balance
                logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
                pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar

            # Record trade
            trade_duration = len(self.features['price']) - self.entry_index
            self.trades.append({
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.current_price,
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': trade_duration,
                'market_direction': self.get_market_direction()
            })

            # Reward based on PnL with stronger penalties for losses
            if pnl_dollar > 0:
                reward += 1.0 + min(pnl_dollar / 10, 5.0)  # Cap positive reward at 5.0
                self.win_count += 1
            else:
                # Stronger penalty for losses, scaled by the size of the loss but capped
                loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
                reward -= loss_penalty
                self.loss_count += 1

                # Extra penalty for closing a losing trade too quickly
                if trade_duration < 5:
                    reward -= 0.5  # Penalty for very short losing trades

            logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Now open long
            self.position = 'long'
            self.entry_price = self.current_price
            self.entry_index = len(self.features['price']) - 1
            self.position_size = self.calculate_position_size()
            self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)

            # Check if this is an optimal buy point
            if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                reward += 2.0  # Bonus for buying at a bottom

            logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

    elif action == 2:  # SELL/SHORT
        if self.position == 'flat':
            # Opening a short position
            self.position = 'short'
            self.entry_price = self.current_price
            self.position_size = self.calculate_position_size()
            # Use the adjusted risk parameters
            self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)

            # Check if this is an optimal sell point (top)
            current_idx = len(self.features['price']) - 1
            if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                reward += 3.0  # Increased bonus for selling at a top

            # Check for volume spike
            if len(self.features['volume']) > 5:
                avg_volume = np.mean(self.features['volume'][-5:-1])
                current_volume = self.features['volume'][-1]
                if current_volume > avg_volume * 1.5:
                    reward += 2.0  # Bonus for entering during high volume

            # Check for price action signals
            if self.features['rsi'][-1] > 70:  # Overbought condition
                reward += 1.5  # Bonus for selling at overbought levels

            # Check if we're selling in a clear downtrend (good)
            if self.is_downtrend():
                reward += 1.0  # Bonus for selling in downtrend
            elif self.is_uptrend():
                reward -= 0.25  # Reduced penalty for selling in uptrend
            else:
                reward += 0.2  # Small reward for opening a position

            logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif self.position == 'long':
            # Close long and open short
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
            pnl_dollar = pnl_percent / 100 * self.position_size

            # Apply fees
            pnl_dollar -= self.calculate_fees(self.position_size)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar

            # Record trade
            self.trades.append({
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.current_price,
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })

            # Reward based on PnL
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                self.win_count += 1
            else:
                reward -= 1.0  # Negative reward for loss
                self.loss_count += 1

            logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Now open short
            self.position = 'short'
            self.entry_price = self.current_price
            self.position_size = self.calculate_position_size()
            self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)

            # Check if this is an optimal sell point
            current_idx = len(self.features['price']) - 1
            if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                reward += 2.0  # Bonus for selling at a top

            logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

    elif action == 3:  # CLOSE
        if self.position == 'long':
            # Close long position
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
            pnl_dollar = pnl_percent / 100 * self.position_size

            # Apply fees
            pnl_dollar -= self.calculate_fees(self.position_size)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar
            self.episode_pnl += pnl_dollar

            # Update max drawdown
            if self.balance > self.peak_balance:
                self.peak_balance = self.balance
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)

            # Record trade
            self.trades.append({
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.current_price,
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })

            # Reward based on PnL
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                self.win_count += 1
            else:
                reward -= 1.0  # Negative reward for loss
                self.loss_count += 1

            logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Reset position
            self.position = 'flat'
            self.entry_price = 0
            self.position_size = 0
            self.stop_loss = 0
            self.take_profit = 0

        elif self.position == 'short':
            # Close short position
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
            pnl_dollar = pnl_percent / 100 * self.position_size

            # Apply fees
            pnl_dollar -= self.calculate_fees(self.position_size)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar
            self.episode_pnl += pnl_dollar

            # Update max drawdown
            if self.balance > self.peak_balance:
                self.peak_balance = self.balance
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)

            # Record trade
            self.trades.append({
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.current_price,
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })

            # Reward based on PnL
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                self.win_count += 1
            else:
                reward -= 1.0  # Negative reward for loss
                self.loss_count += 1

            logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Reset position
            self.position = 'flat'
            self.entry_price = 0
            self.position_size = 0
            self.stop_loss = 0
            self.take_profit = 0

    # Add prediction accuracy component to reward
    if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
        # Compare the first prediction with actual price
        if len(self.data) > 1:
            actual_price = self.data[-1]['close']
            predicted_price = self.predicted_prices[0]
            prediction_error = abs(predicted_price - actual_price) / actual_price

            # Reward accurate predictions, penalize bad ones
            if prediction_error < 0.005:  # Less than 0.5% error
                reward += 0.5
            elif prediction_error > 0.02:  # More than 2% error
                reward -= 0.5

    return reward

def is_downtrend(self):
    """Check if the market is in a downtrend"""
    if len(self.features['price']) < 20:
        return False

    # Use EMAs to determine trend
    short_ema = self.features['ema_9'][-1]
    long_ema = self.features['ema_21'][-1]

    # Downtrend if short EMA is below long EMA
    return short_ema < long_ema

def is_uptrend(self):
    """Check if the market is in an uptrend"""
    if len(self.features['price']) < 20:
        return False

    # Use EMAs to determine trend
    short_ema = self.features['ema_9'][-1]
    long_ema = self.features['ema_21'][-1]

    # Uptrend if short EMA is above long EMA
    return short_ema > long_ema

def get_market_direction(self):
    """Get the current market direction"""
    if self.is_uptrend():
        return "uptrend"
    elif self.is_downtrend():
        return "downtrend"
    else:
        return "sideways"
def analyze_trades(self):
    """Analyze completed trades to identify patterns"""
    if not self.trades:
        return {}

    analysis = {
        'total_trades': len(self.trades),
        'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0),
        'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0),
        'avg_win': 0,
        'avg_loss': 0,
        'avg_duration': 0,
        'uptrend_win_rate': 0,
        'downtrend_win_rate': 0,
        'sideways_win_rate': 0
    }

    # Calculate averages
    wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0]
    losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0]
    durations = [t.get('duration', 0) for t in self.trades]

    analysis['avg_win'] = sum(wins) / len(wins) if wins else 0
    analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0
    analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0

    # Calculate win rates by market direction
    for direction in ['uptrend', 'downtrend', 'sideways']:
        direction_trades = [t for t in self.trades if t.get('market_direction') == direction]
        if direction_trades:
            wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0)
            analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100

    return analysis

def initialize_price_predictor(self, device="cpu"):
    """Initialize the price prediction model"""
    self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
    self.price_predictor.to(device)
    self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
    self.predicted_prices = np.array([])

def train_price_predictor(self):
    """Train the price prediction model on recent data"""
    if len(self.features['price']) < 35:
        return 0.0

    # Get price history
    price_history = self.features['price']

    # Train the model
    loss = self.price_predictor.train_on_new_data(
        price_history,
        self.price_predictor_optimizer,
        epochs=5
    )
    return loss

def update_price_predictions(self):
    """Update price predictions"""
    if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
        self.predicted_prices = np.array([])
        return

    # Get price history
    price_history = self.features['price']

    try:
        # Get predictions
        self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
    except Exception as e:
        logger.warning(f"Error updating predictions: {e}")
        self.predicted_prices = np.array([])

def identify_optimal_trades(self):
    """Identify optimal entry and exit points based on local extrema"""
    if len(self.features['price']) < 20:
        return

    # Find local bottoms and tops
    bottoms, tops = find_local_extrema(self.features['price'], window=5)

    # Store optimal trade points
    self.optimal_bottoms = bottoms  # Buy points
    self.optimal_tops = tops        # Sell points

    # Create optimal trade signals
    self.optimal_signals = np.zeros(len(self.features['price']))
    for i in bottoms:
        if 0 <= i < len(self.optimal_signals):  # Ensure index is valid
            self.optimal_signals[i] = 1  # Buy signal
    for i in tops:
        if 0 <= i < len(self.optimal_signals):  # Ensure index is valid
            self.optimal_signals[i] = -1  # Sell signal

    logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
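
# Each entry in self.trades is a plain dict. Keys observed across this module
# (entries recorded on market closes carry fewer keys than SL/TP closes):
#   type              'long' or 'short'
#   entry / exit      entry and exit prices
#   pnl_percent       PnL as a percentage of the entry price
#   pnl_dollar        PnL in dollars, after fees
#   duration          holding time in candles (when recorded)
#   market_direction  'uptrend' / 'downtrend' / 'sideways' (when recorded)
#   reason            'stop_loss' or 'take_profit' (SL/TP closes only)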
class TradingEnvironment:
    """Trading environment for reinforcement learning with enhanced features"""

    def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
        """Initialize trading environment

        Args:
            initial_balance: Starting account balance
            window_size: Number of candles in the state window
            demo: Whether to run in demo mode
        """
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.window_size = window_size
        self.demo = demo
        self.data = []
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.trades = []
        self.win_count = 0
        self.loss_count = 0
        self.total_pnl = 0.0
        self.episode_pnl = 0.0
        self.peak_balance = initial_balance
        self.max_drawdown = 0.0
        self.current_step = 0
        self.current_price = 0

        # Risk management parameters (adjusted for more aggressive trading)
        self.stop_loss_pct = STOP_LOSS_PERCENT * 0.8      # Tighter stop loss (80% of original)
        self.take_profit_pct = TAKE_PROFIT_PERCENT * 1.5  # Higher take profit (150% of original)
        self.trailing_stop_activated = False
        self.trailing_stop_distance = 0
        self.max_position_size_pct = 0.8  # Use up to 80% of balance for position size

        # For tracking signals for visualization
        self.trade_signals = []

        # Initialize features
        self.features = {
            'price': [], 'volume': [], 'rsi': [],
            'macd': [], 'macd_signal': [], 'macd_hist': [],
            'bollinger_upper': [], 'bollinger_mid': [], 'bollinger_lower': [],
            'stoch_k': [], 'stoch_d': [], 'ema_9': [], 'ema_21': [], 'atr': []
        }

        # Initialize price predictor
        self.price_predictor = None
        self.predicted_prices = np.array([])

        # Initialize optimal trade tracking
        self.optimal_bottoms = []
        self.optimal_tops = []
        self.optimal_signals = np.array([])

        # Futures trading attributes
        self.leverage = MAX_LEVERAGE
        self.futures_symbol = "ETH_USDT"  # Example futures symbol
        self.position_mode = "hedge"      # For simultaneous long/short positions
        self.margin_mode = "cross"        # Cross margin mode

    def reset(self):
        """Reset the environment to initial state"""
        self.balance = self.initial_balance
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.trades = []
        self.win_count = 0
        self.loss_count = 0
        self.episode_pnl = 0.0
        self.peak_balance = self.initial_balance
        self.max_drawdown = 0.0
        self.current_step = 0

        # Keep data but reset current position
        if len(self.data) > self.window_size:
            self.current_step = self.window_size
            self.current_price = self.data[self.current_step]['close']

        # Reset trade signals
        self.trade_signals = []

        return self.get_state()

    def add_data(self, candle):
        """Add a new candle to the data"""
        self.data.append(candle)
        self._update_features()
        self.current_price = candle['close']
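
    # Candle dicts fed to add_data are expected to carry at least the keys
    # used elsewhere in this class: 'timestamp', 'high', 'low', 'close',
    # 'volume' (plus 'open' for a full OHLCV row). For example (values are
    # illustrative only):
    #   env.add_data({'timestamp': 1700000000000, 'open': 2000.0,
    #                 'high': 2010.0, 'low': 1995.0, 'close': 2005.0,
    #                 'volume': 1234.5})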
    def _initialize_features(self):
        """Initialize technical indicators and features"""
        if len(self.data) < 30:
            return

        # Convert data to a pandas DataFrame for easier calculation
        df = pd.DataFrame(self.data)

        # Basic price and volume
        self.features['price'] = df['close'].values
        self.features['volume'] = df['volume'].values

        # Calculate RSI (14 periods)
        delta = df['close'].diff()
        gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
        rs = gain / loss
        self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values

        # Calculate MACD
        ema12 = df['close'].ewm(span=12, adjust=False).mean()
        ema26 = df['close'].ewm(span=26, adjust=False).mean()
        macd = ema12 - ema26
        signal = macd.ewm(span=9, adjust=False).mean()
        self.features['macd'] = macd.values
        self.features['macd_signal'] = signal.values
        self.features['macd_hist'] = (macd - signal).values

        # Calculate Bollinger Bands
        sma20 = df['close'].rolling(window=20).mean()
        std20 = df['close'].rolling(window=20).std()
        self.features['bollinger_upper'] = (sma20 + 2 * std20).values
        self.features['bollinger_mid'] = sma20.values
        self.features['bollinger_lower'] = (sma20 - 2 * std20).values

        # Calculate Stochastic Oscillator
        low_14 = df['low'].rolling(window=14).min()
        high_14 = df['high'].rolling(window=14).max()
        k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
        self.features['stoch_k'] = k.values
        self.features['stoch_d'] = k.rolling(window=3).mean().values

        # Calculate EMAs
        self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
        self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values

        # Calculate ATR
        high_low = df['high'] - df['low']
        high_close = (df['high'] - df['close'].shift()).abs()
        low_close = (df['low'] - df['close'].shift()).abs()
        tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
        self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values

    def _update_features(self):
        """Update technical indicators with new data"""
        self._initialize_features()  # Recalculate all features

    async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
        """Fetch initial historical data for the environment"""
        try:
            logger.info(f"Fetching initial data for {symbol}")

            # Use the refactored fetch method
            data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)

            # Update environment with fetched data
            if data:
                self.data = data
                self._initialize_features()
                logger.info(f"Initialized environment with {len(data)} candles")
            else:
                logger.warning("No initial data received")

            return len(data) > 0
        except Exception as e:
            logger.error(f"Error fetching initial data: {e}")
            return False
    def step(self, action):
        """Take an action in the environment and return the next state, reward, and done flag"""
        # Check if we have enough data
        if self.current_step >= len(self.data) - 1:
            # We've reached the end of data
            done = True
            next_state = self.get_state()
            info = {
                'action': 'none',
                'price': self.current_price,
                'balance': self.balance,
                'position': self.position,
                'pnl': self.total_pnl
            }
            return next_state, 0, done, info

        # Adapt trading parameters to current market conditions
        self.adapt_trading_parameters_to_market()

        # Store current price before taking action
        self.current_price = self.data[self.current_step]['close']

        # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
        reward = self.calculate_reward(action)

        # Record trade signal for visualization
        if action > 0:  # If not HOLD
            signal_type = None
            if action == 1:    # BUY/LONG
                signal_type = 'buy'
            elif action == 2:  # SELL/SHORT
                signal_type = 'sell'
            elif action == 3:  # CLOSE
                if self.position == 'long':
                    signal_type = 'close_long'
                elif self.position == 'short':
                    signal_type = 'close_short'

            if signal_type:
                self.trade_signals.append({
                    'timestamp': self.data[self.current_step]['timestamp'],
                    'price': self.current_price,
                    'type': signal_type,
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })

        # Check for stop loss / take profit hits
        self.check_sl_tp()

        # Move to next step
        self.current_step += 1
        done = self.current_step >= len(self.data) - 1

        # Get new state
        next_state = self.get_state()

        # Create info dictionary
        info = {
            'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
            'price': self.current_price,
            'balance': self.balance,
            'position': self.position,
            'pnl': self.total_pnl
        }

        return next_state, reward, done, info
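
    # Worked example of the PnL math used below (illustrative numbers):
    # a long entered at 2000 with stop_loss_pct = 0.4 puts the stop at
    # 2000 * (1 - 0.4/100) = 1992. If it is hit with a $100 position,
    # pnl_percent = (1992 - 2000) / 2000 * 100 = -0.4 and
    # pnl_dollar = -0.4 / 100 * 100 = -$0.40, minus trading fees.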
    def check_sl_tp(self):
        """Check if stop loss or take profit has been hit"""
        if self.position == 'flat':
            return

        if self.position == 'long':
            # Check stop loss
            if self.current_price <= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': 'stop_loss'
                })

                # Update win/loss count
                self.loss_count += 1

                logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Record signal for visualization
                self.trade_signals.append({
                    'timestamp': self.data[self.current_step]['timestamp'],
                    'price': self.stop_loss,
                    'type': 'stop_loss_long',
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

            # Check take profit
            elif self.current_price >= self.take_profit:
                # Take profit hit
                pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update peak balance
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': 'take_profit'
                })

                # Update win/loss count
                self.win_count += 1

                logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Record signal for visualization
                self.trade_signals.append({
                    'timestamp': self.data[self.current_step]['timestamp'],
                    'price': self.take_profit,
                    'type': 'take_profit_long',
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

        elif self.position == 'short':
            # Check stop loss
            if self.current_price >= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': 'stop_loss'
                })

                # Update win/loss count
                self.loss_count += 1

                logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Record signal for visualization
                self.trade_signals.append({
                    'timestamp': self.data[self.current_step]['timestamp'],
                    'price': self.stop_loss,
                    'type': 'stop_loss_short',
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

            # Check take profit
            elif self.current_price <= self.take_profit:
                # Take profit hit
                pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update peak balance
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance

                # Record trade
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': 'take_profit'
                })

                # Update win/loss count
                self.win_count += 1

                logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Record signal for visualization
                self.trade_signals.append({
                    'timestamp': self.data[self.current_step]['timestamp'],
                    'price': self.take_profit,
                    'type': 'take_profit_short',
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

    def get_state(self):
        """Create state representation for the agent with enhanced features"""
        # Ensure we have enough data
        if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
            # Return zeros if not enough data
            return np.zeros(STATE_SIZE)

        # Create a normalized state vector with recent price action and indicators
        state_components = []
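        # Layout of the state vector (component sizes sum to STATE_SIZE = 64):
        #   10 recent prices + 5 volumes + 3 RSI + 9 MACD (line/signal/hist)
        #   + 3 Bollinger positions + 6 stochastic (K and D) + 3 predicted
        #   prices + 5 extrema signals + 5 position info + 3 momentum
        #   + 3 volatility + 5 market regime + 4 trade history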
features try: # Volume features (normalize by max volume) max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1 vol_features = np.array(self.features['volume'][-5:]) / max_vol state_components.append(vol_features) except (IndexError, ZeroDivisionError): # If we can't get volume features, use zeros state_components.append(np.zeros(5)) # Technical indicators rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1 state_components.append(rsi) # MACD (normalize) macd_vals = np.array(self.features['macd'][-3:]) macd_signal = np.array(self.features['macd_signal'][-3:]) macd_hist = np.array(self.features['macd_hist'][-3:]) macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5) macd_norm = macd_vals / macd_scale macd_signal_norm = macd_signal / macd_scale macd_hist_norm = macd_hist / macd_scale state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm]) # Bollinger position (where is price relative to bands) bb_upper = np.array(self.features['bollinger_upper'][-3:]) bb_lower = np.array(self.features['bollinger_lower'][-3:]) bb_mid = np.array(self.features['bollinger_mid'][-3:]) price = np.array(self.features['price'][-3:]) # Calculate position of price within Bollinger Bands (0 to 1) bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] state_components.append(np.array(bb_pos)) # Stochastic oscillator state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) # Add predicted prices (if available) if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Normalize predictions relative to current price pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0 state_components.append(pred_norm) else: # Add zeros if no predictions state_components.append(np.zeros(3)) # Add extrema signals (if available) if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: # Get recent signals idx = len(self.optimal_signals) - 5 if idx < 0: idx = 0 recent_signals = self.optimal_signals[idx:idx+5] # Pad if needed if len(recent_signals) < 5: recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') state_components.append(recent_signals) else: # Add zeros if no signals state_components.append(np.zeros(5)) # Position info position_info = np.zeros(5) if self.position == 'long': position_info[0] = 1.0 # Position is long position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss % position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance elif self.position == 'short': position_info[0] = -1.0 # Position is short position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss % position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance state_components.append(position_info) # NEW FEATURES START HERE # 1. 
Price momentum features (rate of change over different periods) if len(self.features['price']) >= 20: roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0 momentum_features = np.array([roc_5, roc_10, roc_20]) state_components.append(momentum_features) else: state_components.append(np.zeros(3)) # 2. Volatility features if len(self.features['price']) >= 20: # Calculate price returns returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] # Calculate volatility (standard deviation of returns) volatility = np.std(returns) # Calculate normalized high-low range high_low_range = np.mean([ (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] for i in range(max(0, len(self.data)-5), len(self.data)) ]) if len(self.data) > 0 else 0 # ATR normalized by price atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 volatility_features = np.array([volatility, high_low_range, atr_norm]) state_components.append(volatility_features) else: state_components.append(np.zeros(3)) # 3. Market regime features if len(self.features['price']) >= 50: # Trend strength (ADX-like measure) ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price trend_strength = abs(ema9 - ema21) / ema21 # Detect if in range or trending is_range_bound = 1.0 if self.is_uncertain_market() else 0.0 is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0 # Detect if near support/resistance near_support = 1.0 if self.is_near_support() else 0.0 near_resistance = 1.0 if self.is_near_resistance() else 0.0 market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance]) state_components.append(market_regime) else: state_components.append(np.zeros(5)) # 4. 
Trade history features if len(self.trades) > 0: # Recent win/loss ratio recent_trades = self.trades[-min(10, len(self.trades)):] win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades) # Average profit/loss avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0 avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0 # Normalize by balance avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0 avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0 # Last trade result last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0 trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]) state_components.append(trade_history) else: state_components.append(np.zeros(4)) # Combine all features state = np.concatenate([comp.flatten() for comp in state_components]) # Replace any NaN or infinite values state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) # Ensure the state has the correct size if len(state) != STATE_SIZE: logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}") # Pad or truncate to match expected size if len(state) < STATE_SIZE: state = np.pad(state, (0, STATE_SIZE - len(state))) else: state = state[:STATE_SIZE] return state def get_expanded_state_size(self): """Calculate the size of the expanded state representation""" # Create a dummy state to get its size state = self.get_state() return len(state) async def expand_model_with_new_features(agent, env): """Expand the model to handle new features without retraining from scratch""" # Get the new state size new_state_size = env.get_expanded_state_size() # Only expand if the new state size is larger if new_state_size > agent.state_size: logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})") # Expand the model success = agent.expand_model( new_state_size=new_state_size, new_hidden_size=512, # Increase hidden size for more capacity new_lstm_layers=3, # More layers for deeper patterns new_attention_heads=8 # More attention heads for complex relationships ) if success: logger.info(f"Model successfully expanded to handle {new_state_size} features") return True else: logger.error("Failed to expand model") return False else: logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient") return True def calculate_reward(self, action): """Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals""" reward = 0 # Validate current price if self.current_price <= 0 or self.current_price > 1000000: # Reasonable price range logger.error(f"Invalid current price: {self.current_price}") return -10.0 # Strong penalty for invalid price # Validate position size if self.position_size <= 0 or self.position_size > 1000000: # Reasonable position size range logger.error(f"Invalid position size: {self.position_size}") return -10.0 # Strong penalty for invalid position size if action == 1: # BUY/LONG if self.position == 'flat': # Opening a long position self.position = 'long' self.entry_price = self.current_price self.position_size = self.calculate_position_size() # Use the adjusted risk parameters self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100) 
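# --- Illustrative sketch (standalone; the helper name is hypothetical) ------
# The entry branch above turns the percentage risk parameters into absolute
# stop-loss / take-profit prices. A minimal version of that arithmetic for a
# long entry, using the module's default percentages as example values:
def sl_tp_for_long(entry_price, stop_loss_pct=0.5, take_profit_pct=1.5):
    """Return (stop_loss, take_profit) prices; percentages are whole numbers (0.5 == 0.5%)."""
    stop_loss = entry_price * (1 - stop_loss_pct / 100)
    take_profit = entry_price * (1 + take_profit_pct / 100)
    return stop_loss, take_profit

# Example: an entry at $2000 gives a stop at $1990.00 and a target at $2030.00.
# At 100x leverage a 0.5% adverse move is roughly a 50% hit to margin, which is
# why the defaults keep the stop this tight.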
                self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
                self.entry_index = len(self.features['price']) - 1  # Record the entry bar so duration and extrema checks use this trade, not a stale index

                # Check if this is an optimal buy point
                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                    reward += 2.0  # Bonus for buying at a bottom

                # Check for volume spike
                if len(self.features['volume']) > 5:
                    avg_volume = np.mean(self.features['volume'][-5:-1])
                    current_volume = self.features['volume'][-1]
                    if current_volume > avg_volume * 1.5:
                        reward += 2.0  # Bonus for entering during high volume

                # Check for price action signals
                if self.features['rsi'][-1] < 30:  # Oversold condition
                    reward += 1.5  # Bonus for buying at oversold levels

                # Check if we're buying in a clear uptrend (good)
                if self.is_uptrend():
                    reward += 1.0  # Bonus for buying in uptrend
                elif self.is_downtrend():
                    reward -= 0.25  # Reduced penalty for buying in downtrend
                else:
                    reward += 0.2  # Small reward for opening a position

                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

            elif self.position == 'short':
                # Close short and open long
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Validate PnL values
                if abs(pnl_percent) > 100:  # Max 100% loss/gain
                    logger.error(f"Invalid PnL percentage: {pnl_percent}")
                    pnl_percent = max(min(pnl_percent, 100), -100)
                    pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance with validation
                if abs(pnl_dollar) > self.balance * 2:  # Max 200% of balance
                    logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
                    pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Record trade
                trade_duration = len(self.features['price']) - self.entry_index
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': trade_duration,
                    'market_direction': self.get_market_direction()
                })

                # Reward based on PnL with stronger penalties for losses
                if pnl_dollar > 0:
                    reward += 1.0 + min(pnl_dollar / 10, 5.0)  # Cap the PnL bonus at 5.0
                    self.win_count += 1
                else:
                    # Stronger penalty for losses, scaled by the size of the loss but capped
                    loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
                    reward -= loss_penalty
                    self.loss_count += 1

                    # Extra penalty for closing a losing trade too quickly
                    if trade_duration < 5:
                        reward -= 0.5  # Penalty for very short losing trades

                logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Now open long
                self.position = 'long'
                self.entry_price = self.current_price
                self.entry_index = len(self.features['price']) - 1
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
                self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)

                # Check if this is an optimal buy point
                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                    reward += 2.0  # Bonus for buying at a bottom

                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif action == 2:  # SELL/SHORT
            if self.position == 'flat':
                # Opening a short position
                self.position = 'short'
                self.entry_price = self.current_price
                self.position_size = self.calculate_position_size()
                # Use the adjusted risk parameters
                self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
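                # A worked example (hypothetical dollar amounts) of the asymmetric
                # reward shaping used when a short is closed and reversed above:
                #   win:  reward += 1.0 + min(pnl_dollar / 10, 5.0)    -> pnl = +$8  gives +1.8
                #   loss: reward -= min(1.0 + abs(pnl_dollar) / 5, 5.0) -> pnl = -$12 gives -3.4
                # plus a further -0.5 if the losing trade lasted fewer than 5 steps,
                # so a quick $12 loss costs -3.9 while a $12 win earns only +2.2:
                # losses are deliberately penalized roughly twice as hard as
                # comparable wins are rewarded.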
                self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)
                self.entry_index = len(self.features['price']) - 1  # Record the entry bar for this short as well

                # Check if this is an optimal sell point (top)
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                    reward += 3.0  # Increased bonus for selling at a top

                # Check for volume spike
                if len(self.features['volume']) > 5:
                    avg_volume = np.mean(self.features['volume'][-5:-1])
                    current_volume = self.features['volume'][-1]
                    if current_volume > avg_volume * 1.5:
                        reward += 2.0  # Bonus for entering during high volume

                # Check for price action signals
                if self.features['rsi'][-1] > 70:  # Overbought condition
                    reward += 1.5  # Bonus for selling at overbought levels

                # Check if we're selling in a clear downtrend (good)
                if self.is_downtrend():
                    reward += 1.0  # Bonus for selling in downtrend
                elif self.is_uptrend():
                    reward -= 0.25  # Reduced penalty for selling in uptrend
                else:
                    reward += 0.2  # Small reward for opening a position

                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

            elif self.position == 'long':
                # Close long and open short
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0  # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Now open short
                self.position = 'short'
                self.entry_price = self.current_price
                self.entry_index = len(self.features['price']) - 1
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
                self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)

                # Check if this is an optimal sell point
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                    reward += 2.0  # Bonus for selling at a top

                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif action == 3:  # CLOSE
            if self.position == 'long':
                # Close long position
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0  # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
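# --- Illustrative sketch (standalone; the helper name is hypothetical) ------
# Every close/reverse branch above reduces to the same fee-adjusted PnL
# arithmetic. A minimal version, assuming the single flat 0.1% fee that
# calculate_fees() charges on the position notional:
def net_pnl_dollar(entry, exit_price, position_size, side, fee_rate=0.001):
    """Dollar PnL for a position whose notional is `position_size` in quote currency."""
    if side == 'long':
        pnl_percent = (exit_price - entry) / entry * 100
    else:  # 'short'
        pnl_percent = (entry - exit_price) / entry * 100
    return pnl_percent / 100 * position_size - position_size * fee_rate

# Example: a $500 short from $2000 closed at $1980 is +$5.00 gross, -$0.50 in
# fees, +$4.50 net. (A real venue would usually charge the fee on both the
# entry and the exit fill, not once.)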
self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': # Close short position pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance drawdown = (self.peak_balance - self.balance) / self.peak_balance self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar }) # Reward based on PnL if pnl_dollar > 0: reward += 1.0 + pnl_dollar / 10 # Positive reward for profit self.win_count += 1 else: reward -= 1.0 # Negative reward for loss self.loss_count += 1 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.entry_price = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 # Add prediction accuracy component to reward if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Compare the first prediction with actual price if len(self.data) > 1: actual_price = self.data[-1]['close'] predicted_price = self.predicted_prices[0] prediction_error = abs(predicted_price - actual_price) / actual_price # Reward accurate predictions, penalize bad ones if prediction_error < 0.005: # Less than 0.5% error reward += 0.5 elif prediction_error > 0.02: # More than 2% error reward -= 0.5 return reward def is_downtrend(self): """Check if the market is in a downtrend""" if len(self.features['price']) < 20: return False # Use EMA to determine trend short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] # Downtrend if short EMA is below long EMA return short_ema < long_ema def is_uptrend(self): """Check if the market is in an uptrend""" if len(self.features['price']) < 20: return False # Use EMA to determine trend short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] # Uptrend if short EMA is above long EMA return short_ema > long_ema def get_market_direction(self): """Get the current market direction""" if self.is_uptrend(): return "uptrend" elif self.is_downtrend(): return "downtrend" else: return "sideways" def analyze_trades(self): """Analyze completed trades to identify patterns""" if not self.trades: return {} analysis = { 'total_trades': len(self.trades), 'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0), 'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0), 'avg_win': 0, 'avg_loss': 0, 'avg_duration': 0, 'uptrend_win_rate': 0, 'downtrend_win_rate': 0, 'sideways_win_rate': 0 } # Calculate averages wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0] losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0] durations = [t.get('duration', 0) for t in self.trades] analysis['avg_win'] = sum(wins) / len(wins) if wins else 0 analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0 analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0 # Calculate win rates by market direction for direction in ['uptrend', 'downtrend', 'sideways']: direction_trades = [t for t in self.trades if t.get('market_direction') == direction] if 
direction_trades: wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0) analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100 return analysis def initialize_price_predictor(self, device="cpu"): """Initialize the price prediction model""" self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5) self.price_predictor.to(device) self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3) self.predicted_prices = np.array([]) def train_price_predictor(self): """Train the price prediction model on recent data""" if len(self.features['price']) < 35: return 0.0 # Get price history price_history = self.features['price'] # Train the model loss = self.price_predictor.train_on_new_data( price_history, self.price_predictor_optimizer, epochs=5 ) return loss def update_price_predictions(self): """Update price predictions""" if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None: self.predicted_prices = np.array([]) return # Get price history price_history = self.features['price'] try: # Get predictions self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5) except Exception as e: logger.warning(f"Error updating predictions: {e}") self.predicted_prices = np.array([]) def identify_optimal_trades(self): """Identify optimal entry and exit points based on local extrema""" if len(self.features['price']) < 20: return # Find local bottoms and tops bottoms, tops = find_local_extrema(self.features['price'], window=5) # Store optimal trade points self.optimal_bottoms = bottoms # Buy points self.optimal_tops = tops # Sell points # Create optimal trade signals self.optimal_signals = np.zeros(len(self.features['price'])) for i in bottoms: if 0 <= i < len(self.optimal_signals): # Ensure index is valid self.optimal_signals[i] = 1 # Buy signal for i in tops: if 0 <= i < len(self.optimal_signals): # Ensure index is valid self.optimal_signals[i] = -1 # Sell signal logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points") def calculate_position_size(self): """Calculate position size based on current balance and risk parameters Returns: float: Position size in quote currency """ # More aggressive position sizing risk_amount = self.balance * (self.max_position_size_pct * random.uniform(0.7, 1.0)) # In futures trading, adjust for leverage if hasattr(self, 'leverage') and self.leverage > 1: risk_amount = min(risk_amount * self.leverage, self.balance * 10) # Limit max risk return risk_amount def calculate_fees(self, position_size): """Calculate trading fees for a given position size""" # Typical fee rate for crypto exchanges (0.1%) fee_rate = 0.001 # Calculate fee fee = position_size * fee_rate return fee def is_uncertain_market(self): """Check if the market is in an uncertain/sideways state""" if len(self.features['price']) < 20: return True # Check if price is within a narrow range recent_prices = self.features['price'][-20:] price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices) # Check if EMAs are close to each other if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] ema_diff = abs(short_ema - long_ema) / long_ema # Return True if price range is small and EMAs are close return price_range < 0.02 and ema_diff < 0.005 return price_range < 0.015 # Very 
narrow range def is_near_support(self): """Check if current price is near a support level""" if not hasattr(self, 'features') or len(self.features['price']) < 30: return False # Find recent lows prices = self.features['price'][-30:] lows = [] for i in range(1, len(prices)-1): if prices[i] < prices[i-1] and prices[i] < prices[i+1]: lows.append(prices[i]) if not lows: return False # Check if current price is near any of these lows current_price = self.current_price for low in lows: if abs(current_price - low) / low < 0.01: # Within 1% of a recent low return True return False def is_near_resistance(self): """Check if current price is near a resistance level""" if not hasattr(self, 'features') or len(self.features['price']) < 30: return False # Find recent highs prices = self.features['price'][-30:] highs = [] for i in range(1, len(prices)-1): if prices[i] > prices[i-1] and prices[i] > prices[i+1]: highs.append(prices[i]) if not highs: return False # Check if current price is near any of these highs current_price = self.current_price for high in highs: if abs(current_price - high) / high < 0.01: # Within 1% of a recent high return True return False def is_market_turning(self): """Check if the market is potentially changing direction""" if len(self.features['price']) < 20: return False # Check for divergence between price and momentum indicators if len(self.features['rsi']) > 5: # Price making higher highs but RSI making lower highs (bearish divergence) price_trend = self.features['price'][-1] > self.features['price'][-5] rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5] if price_trend != rsi_trend: return True # Check for EMA crossover if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1: short_ema_prev = self.features['ema_9'][-2] long_ema_prev = self.features['ema_21'][-2] short_ema_curr = self.features['ema_9'][-1] long_ema_curr = self.features['ema_21'][-1] # Check if EMAs just crossed if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \ (short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr): return True return False def is_market_against_position(self, position_type): """Check if market conditions have turned against the current position""" if position_type == 'long': # For long positions, check if market has turned bearish return self.is_downtrend() and not self.is_near_support() elif position_type == 'short': # For short positions, check if market has turned bullish return self.is_uptrend() and not self.is_near_resistance() return False def is_near_optimal_exit(self, position_type): """Check if current price is near an optimal exit point for the position""" current_idx = len(self.features['price']) - 1 if position_type == 'long' and hasattr(self, 'optimal_tops'): # For long positions, optimal exit is near tops for top_idx in self.optimal_tops: if abs(current_idx - top_idx) < 3: # Within 3 candles of a top return True elif position_type == 'short' and hasattr(self, 'optimal_bottoms'): # For short positions, optimal exit is near bottoms for bottom_idx in self.optimal_bottoms: if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom return True return False def calculate_future_profit_potential(self, position_type, lookahead=20): """ Calculate potential profit if position is held for a certain period This is used for retrospective backtesting rewards Args: position_type: 'long' or 'short' lookahead: Number of candles to look ahead Returns: Potential profit percentage """ if len(self.data) <= 1 or 
self.current_step >= len(self.data):
            return 0

        # Get current price
        current_price = self.current_price

        # Get future prices (if available in historical data)
        future_prices = []
        current_idx = self.current_step

        # Safely get future prices
        for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
            if current_idx + i < len(self.data):
                future_prices.append(self.data[current_idx + i]['close'])

        if not future_prices:
            return 0

        # Calculate potential profit
        if position_type == 'long':
            # For long positions, find the maximum price in the future
            max_future_price = max(future_prices)
            potential_profit = (max_future_price - current_price) / current_price * 100
        else:  # short
            # For short positions, find the minimum price in the future
            min_future_price = min(future_prices)
            potential_profit = (current_price - min_future_price) / current_price * 100

        return potential_profit

    async def initialize_futures(self, exchange):
        """Initialize futures trading parameters"""
        if not self.demo:
            try:
                # Set up futures trading parameters
                await exchange.set_position_mode(True)  # Hedge mode
                await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
                await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
                logger.info(f"Futures initialized with {self.leverage}x leverage")
            except Exception as e:
                logger.error(f"Failed to initialize futures trading: {str(e)}")
                logger.info("Falling back to demo mode for safety")
                self.demo = True

    async def execute_real_trade(self, exchange, action, current_price):
        """Execute real futures trade on MEXC"""
        try:
            position_size = self.calculate_position_size()
            order = None  # Defined return value for actions that place no order (e.g. HOLD)

            if action == 1:  # Open long
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='buy',
                    amount=position_size,
                    params={'positionSide': 'LONG'}
                )
                logger.info(f"Opened LONG position: {order}")
            elif action == 2:  # Open short
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='sell',
                    amount=position_size,
                    params={'positionSide': 'SHORT'}
                )
                logger.info(f"Opened SHORT position: {order}")
            elif action == 3:  # Close position
                position_side = 'LONG' if self.position == 'long' else 'SHORT'
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='sell' if position_side == 'LONG' else 'buy',
                    amount=self.position_size,
                    params={'positionSide': position_side}
                )
                logger.info(f"Closed {position_side} position: {order}")

            return order
        except Exception as e:
            logger.error(f"Trade execution failed: {e}")
            return None

    def is_volatile_market(self):
        """Detect if the market is currently in a volatile state with significant price movements

        Returns:
            bool: True if market is volatile, False otherwise
        """
        if len(self.features['price']) < 20:
            return False

        # Calculate recent price volatility
        recent_prices = self.features['price'][-20:]
        returns = np.diff(recent_prices) / recent_prices[:-1]
        volatility = np.std(returns) * 100  # Convert to percentage

        # Calculate volume increase
        recent_volumes = self.features['volume'][-10:]
        avg_volume_prev = np.mean(self.features['volume'][-20:-10])
        avg_volume_recent = np.mean(recent_volumes)
        volume_increase = avg_volume_recent / avg_volume_prev if avg_volume_prev > 0 else 1.0

        # Calculate ATR if available
        atr_high = False
        if len(self.features['atr']) > 5:
            recent_atr = self.features['atr'][-1]
            avg_atr = np.mean(self.features['atr'][-20:-1])
            atr_ratio = recent_atr / avg_atr if avg_atr > 0 else 1.0
            atr_high = atr_ratio > 1.5

        # Check if price moved significantly in either direction recently
        price_range_percent = (max(recent_prices) -
min(recent_prices)) / min(recent_prices) * 100 # Market is volatile if any of these conditions are met volatile = ( volatility > 0.5 or # High standard deviation of returns volume_increase > 1.8 or # Volume spike price_range_percent > 1.5 or # Large price range atr_high # High ATR relative to average ) if volatile: logger.info(f"Volatile market detected - Volatility: {volatility:.2f}%, Volume increase: {volume_increase:.2f}x, Price range: {price_range_percent:.2f}%") return volatile def adapt_trading_parameters_to_market(self): """Dynamically adjust trading parameters based on market conditions Returns: None """ # Check market conditions is_volatile = self.is_volatile_market() is_trending_up = self.is_uptrend() is_trending_down = self.is_downtrend() # Base parameters base_stop_loss = STOP_LOSS_PERCENT base_take_profit = TAKE_PROFIT_PERCENT base_position_size = 0.5 # 50% of max # Adjust based on market conditions if is_volatile: # In volatile markets, use tighter stops but higher take profits self.stop_loss_pct = base_stop_loss * 0.7 # Tighter stop self.take_profit_pct = base_take_profit * 1.8 # Higher target self.max_position_size_pct = base_position_size * 1.3 # More aggressive sizing elif is_trending_up: # In uptrends, use looser stops for longs, tighter for shorts if self.position == 'long' or self.position == 'flat': self.stop_loss_pct = base_stop_loss * 0.9 self.take_profit_pct = base_take_profit * 1.6 self.max_position_size_pct = base_position_size * 1.2 else: # More conservative for shorts in uptrend self.stop_loss_pct = base_stop_loss * 0.7 self.take_profit_pct = base_take_profit * 1.2 self.max_position_size_pct = base_position_size * 0.8 elif is_trending_down: # In downtrends, use looser stops for shorts, tighter for longs if self.position == 'short' or self.position == 'flat': self.stop_loss_pct = base_stop_loss * 0.9 self.take_profit_pct = base_take_profit * 1.6 self.max_position_size_pct = base_position_size * 1.2 else: # More conservative for longs in downtrend self.stop_loss_pct = base_stop_loss * 0.7 self.take_profit_pct = base_take_profit * 1.2 self.max_position_size_pct = base_position_size * 0.8 else: # In sideways/uncertain markets, be more balanced self.stop_loss_pct = base_stop_loss * 0.8 self.take_profit_pct = base_take_profit * 1.3 self.max_position_size_pct = base_position_size # Log the adaptation logger.debug(f"Adapted trading parameters - Stop loss: {self.stop_loss_pct:.2f}%, Take profit: {self.take_profit_pct:.2f}%, Max position size: {self.max_position_size_pct*100:.1f}%") # Ensure GPU usage if available def get_device(): """Get the best available device (CUDA GPU or CPU)""" if torch.cuda.is_available(): device = torch.device("cuda") logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}") # Set up for mixed precision training torch.backends.cudnn.benchmark = True else: device = torch.device("cpu") logger.info("GPU not available, using CPU") return device # Update Agent class to use GPU properly class Agent: def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None): """Initialize Agent with architecture parameters stored as attributes""" self.state_size = state_size self.action_size = action_size self.hidden_size = hidden_size # Store hidden_size as an instance attribute self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute self.attention_heads = attention_heads # Store attention_heads as an instance attribute # Set device self.device = device if device is not None else get_device() # 
Initialize networks self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) # Initialize optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Initialize replay memory self.memory = ReplayMemory(MEMORY_SIZE) # Initialize exploration parameters self.epsilon = EPSILON_START self.epsilon_decay = EPSILON_DECAY self.epsilon_min = EPSILON_END # Initialize step counter self.steps_done = 0 # Initialize TensorBoard writer self.writer = None # Initialize GradScaler for mixed precision training self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None # Rest of the initialization code... def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" logger.info(f"Expanding model: {self.state_size} → {new_state_size}, " f"hidden: {self.policy_net.hidden_size} → {new_hidden_size}") # Save old weights old_state_dict = self.policy_net.state_dict() # Create new larger networks new_policy_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) new_target_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) # Transfer weights for common layers new_state_dict = new_policy_net.state_dict() for name, param in old_state_dict.items(): if name in new_state_dict: # If shapes match, copy directly if new_state_dict[name].shape == param.shape: new_state_dict[name] = param # For first layer, copy weights for the original input dimensions elif name == "fc1.weight": new_state_dict[name][:, :self.state_size] = param # For other layers, initialize with a strategy that preserves scale else: logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}") # Load transferred weights new_policy_net.load_state_dict(new_state_dict) new_target_net.load_state_dict(new_state_dict) # Replace networks self.policy_net = new_policy_net self.target_net = new_target_net self.target_net.eval() # Update optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Update state size self.state_size = new_state_size # Print new model size total_params = sum(p.numel() for p in self.policy_net.parameters()) logger.info(f"New model size: {total_params:,} parameters") return True def select_action(self, state, training=True): sample = random.random() if training: # More aggressive epsilon decay for faster exploitation self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \ np.exp(-1.5 * self.steps_done / EPSILON_DECAY) # Increased decay factor self.steps_done += 1 # Lower threshold for exploration, especially in live trading if not training: # In live trading, be much more aggressive with exploitation self.epsilon = max(EPSILON_END, self.epsilon * 0.95) if sample > self.epsilon or not training: with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) action_values = self.policy_net(state_tensor) # Add temperature-based sampling for more aggressive actions # when the model is confident (higher action differences) if not training: # More aggressive in live trading values = action_values.cpu().numpy() max_value = np.max(values) value_diff = max_value - np.mean(values) # 
If there's a clear best action, always take it if value_diff > 0.5: return action_values.max(1)[1].item() return action_values.max(1)[1].item() else: return random.randrange(self.action_size) def learn(self): """Learn from a batch of experiences""" if len(self.memory) < BATCH_SIZE: return None try: # Sample a batch of experiences experiences = self.memory.sample(BATCH_SIZE) # Convert experiences to tensors states = torch.FloatTensor([e.state for e in experiences]).to(self.device) actions = torch.LongTensor([e.action for e in experiences]).to(self.device) rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device) next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device) dones = torch.FloatTensor([e.done for e in experiences]).to(self.device) # Use mixed precision for forward/backward passes if self.device.type == "cuda" and self.scaler is not None: with torch.amp.autocast('cuda'): # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with target network with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass with mixed precision self.optimizer.zero_grad() self.scaler.scale(loss).backward() # Gradient clipping to prevent exploding gradients self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.scaler.step(self.optimizer) self.scaler.update() else: # Standard precision for CPU # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with target network with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping to prevent exploding gradients torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() # Update steps done self.steps_done += 1 # Update target network if self.steps_done % TARGET_UPDATE == 0: self.target_net.load_state_dict(self.policy_net.state_dict()) return loss.item() except Exception as e: logger.error(f"Error during learning: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path="models/trading_agent_best_pnl.pt"): """Save the model in a format compatible with PyTorch 2.6+""" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(path), exist_ok=True) # Ensure architecture parameters are set if not hasattr(self, 'hidden_size'): self.hidden_size = 256 # Default value logger.warning("Setting default hidden_size=256 for saving") if not hasattr(self, 'lstm_layers'): self.lstm_layers = 2 # Default value logger.warning("Setting default lstm_layers=2 for saving") if not hasattr(self, 'attention_heads'): self.attention_heads = 4 # Default value logger.warning("Setting default attention_heads=4 for saving") # Save model state checkpoint = { 'policy_net': 
self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(), 'optimizer': self.optimizer.state_dict(), 'epsilon': self.epsilon, 'state_size': self.state_size, 'action_size': self.action_size, 'hidden_size': self.hidden_size, 'lstm_layers': self.lstm_layers, 'attention_heads': self.attention_heads } # Save scaler state if it exists if hasattr(self, 'scaler') and self.scaler is not None: checkpoint['scaler'] = self.scaler.state_dict() # Save with pickle_protocol=4 for better compatibility torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4) logger.info(f"Model saved to {path}") except Exception as e: logger.error(f"Error saving model: {e}") import traceback logger.error(traceback.format_exc()) def load(self, path="models/trading_agent_best_pnl.pt"): """Load a trained model with improved error handling for PyTorch 2.6 compatibility""" try: # First try to load with weights_only=False (for models saved with older PyTorch versions) try: logger.info(f"Attempting to load model with weights_only=False: {path}") checkpoint = torch.load(path, map_location=self.device, weights_only=False) logger.info("Model loaded successfully with weights_only=False") except Exception as e1: logger.warning(f"Failed to load with weights_only=False: {e1}") # Try with safe_globals context manager try: logger.info("Attempting to load with safe_globals context manager") import numpy as np from torch.serialization import safe_globals # Add numpy scalar to safe globals with safe_globals(['numpy._core.multiarray.scalar']): checkpoint = torch.load(path, map_location=self.device) logger.info("Model loaded successfully with safe_globals") except Exception as e2: logger.warning(f"Failed to load with safe_globals: {e2}") # Last resort: try with pickle_module=pickle logger.info("Attempting to load with pickle_module") import pickle checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False) logger.info("Model loaded successfully with pickle_module") # Load state dictionaries self.policy_net.load_state_dict(checkpoint['policy_net']) self.target_net.load_state_dict(checkpoint['target_net']) # Try to load optimizer state try: self.optimizer.load_state_dict(checkpoint['optimizer']) except Exception as e: logger.warning(f"Could not load optimizer state: {e}") # Load epsilon if available if 'epsilon' in checkpoint: self.epsilon = checkpoint['epsilon'] # Load architecture parameters if available if 'state_size' in checkpoint: self.state_size = checkpoint['state_size'] if 'action_size' in checkpoint: self.action_size = checkpoint['action_size'] if 'hidden_size' in checkpoint: self.hidden_size = checkpoint['hidden_size'] else: # If hidden_size not in checkpoint, infer from model try: self.hidden_size = self.policy_net.fc1.weight.shape[0] logger.info(f"Inferred hidden_size={self.hidden_size} from model") except: self.hidden_size = 256 # Default value logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}") if 'lstm_layers' in checkpoint: self.lstm_layers = checkpoint['lstm_layers'] else: self.lstm_layers = 2 # Default value if 'attention_heads' in checkpoint: self.attention_heads = checkpoint['attention_heads'] else: self.attention_heads = 4 # Default value logger.info(f"Model loaded successfully from {path}") except Exception as e: logger.error(f"Error loading model: {e}") import traceback logger.error(traceback.format_exc()) raise def add_chart_to_tensorboard(self, env, global_step): """Add trading chart to TensorBoard""" try: if 
len(env.data) < 10: return # Create chart image chart_img = create_candlestick_figure( env.data, env.trade_signals, window_size=100, title=f"Trading Chart - Step {global_step}" ) if chart_img is not None: # Convert PIL image to numpy array for TensorBoard chart_array = np.array(chart_img) # TensorBoard expects [C, H, W] format chart_array = np.transpose(chart_array, (2, 0, 1)) self.writer.add_image('Trading Chart', chart_array, global_step) # Add position information as text entry_price = env.entry_price if env.entry_price else 0.00 position_info = f""" **Current Position**: {env.position.upper()} **Entry Price**: ${entry_price:.2f} **Current Price**: ${env.data[-1]['close']:.2f} **Position Size**: ${env.position_size:.2f} **Unrealized PnL**: ${env.total_pnl:.2f} """ self.writer.add_text('Position', position_info, global_step) except Exception as e: logger.error(f"Error adding chart to TensorBoard: {str(e)}") # Continue without visualization rather than crashing async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" # Connect to MEXC websocket uri = "wss://stream.mexc.com/ws" async with websockets.connect(uri) as websocket: # Subscribe to kline data subscribe_msg = { "method": "SUBSCRIPTION", "params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"] } await websocket.send(json.dumps(subscribe_msg)) logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines") while True: try: response = await websocket.recv() data = json.loads(response) if 'data' in data: kline = data['data'] candle = { 'timestamp': kline['t'], 'open': float(kline['o']), 'high': float(kline['h']), 'low': float(kline['l']), 'close': float(kline['c']), 'volume': float(kline['v']) } yield candle except Exception as e: logger.error(f"Websocket error: {e}") # Try to reconnect await asyncio.sleep(5) break async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000): """Train the agent using historical and live data with GPU acceleration""" # Initialize statistics tracking stats = { 'episode_rewards': [], 'episode_lengths': [], 'balances': [], 'win_rates': [], 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], 'prediction_accuracy': [], 'trade_analysis': [] } # Track best models best_reward = float('-inf') best_pnl = float('-inf') # Initialize TensorBoard writer if not already initialized if not hasattr(agent, 'writer') or agent.writer is None: agent.writer = SummaryWriter('runs/training') # Training loop for episode in range(num_episodes): try: # Reset environment state = env.reset() episode_reward = 0 prediction_loss = 0 # Episode loop for step in range(max_steps_per_episode): # Select action action = agent.select_action(state) # Take action try: next_state, reward, done, info = env.step(action) except Exception as e: logger.error(f"Error in step function: {e}") break # Store transition in replay memory agent.memory.push(state, action, reward, next_state, done) # Move to the next state state = next_state # Update episode reward episode_reward += reward # Learn from experience if len(agent.memory) > BATCH_SIZE: agent.learn() # Update price predictions periodically if step % 50 == 0: try: env.update_price_predictions() env.identify_optimal_trades() except Exception as e: logger.warning(f"Error updating predictions: {e}") # Add chart to TensorBoard periodically if step % 50 == 0 or (step == max_steps_per_episode - 1) or done: try: global_step = episode * max_steps_per_episode + step 
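# --- Illustrative sketch (standalone; the helper name is hypothetical) ------
# add_chart_to_tensorboard() above converts the PIL chart into the
# channel-first [C, H, W] array that SummaryWriter.add_image expects by
# default. The conversion in isolation:
def pil_to_tensorboard_array(pil_image):
    """Convert an RGB PIL image (H, W, C) into the (C, H, W) array add_image expects."""
    chart_array = np.array(pil_image)            # shape (H, W, 3), dtype uint8
    return np.transpose(chart_array, (2, 0, 1))  # shape (3, H, W)

# Usage: writer.add_image('Trading Chart', pil_to_tensorboard_array(img), step)
# Passing dataformats='HWC' to add_image would avoid the transpose entirely.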
agent.add_chart_to_tensorboard(env, global_step) except Exception as e: logger.warning(f"Error adding chart to TensorBoard: {e}") # End episode if done if done: break # Update target network periodically if episode % TARGET_UPDATE == 0: agent.update_target_network() # Calculate win rate total_trades = env.win_count + env.loss_count win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 # Train price predictor try: if episode % 5 == 0 and len(env.data) > 50: prediction_loss = env.train_price_predictor() except Exception as e: logger.warning(f"Error training price predictor: {e}") prediction_loss = 0 # Analyze trades try: trade_analysis = env.analyze_trades() stats['trade_analysis'].append(trade_analysis) except Exception as e: logger.warning(f"Error analyzing trades: {e}") trade_analysis = {} stats['trade_analysis'].append({}) # Calculate prediction accuracy prediction_accuracy = 0.0 try: if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: if len(env.data) > 5: actual_prices = [candle['close'] for candle in env.data[-5:]] predicted = env.predicted_prices[:min(5, len(actual_prices))] errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])] prediction_accuracy = 100 * (1 - sum(errors) / len(errors)) except Exception as e: logger.warning(f"Error calculating prediction accuracy: {e}") # Log statistics stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) stats['balances'].append(env.balance) stats['win_rates'].append(win_rate) stats['episode_pnls'].append(env.episode_pnl) stats['cumulative_pnl'].append(env.total_pnl) stats['drawdowns'].append(env.max_drawdown * 100) stats['prediction_accuracy'].append(prediction_accuracy) # Log detailed trade analysis if trade_analysis: logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, " f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | " f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}") # Log to TensorBoard agent.writer.add_scalar('Reward/train', episode_reward, episode) agent.writer.add_scalar('Balance/train', env.balance, episode) agent.writer.add_scalar('WinRate/train', win_rate, episode) agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode) agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode) agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) agent.writer.add_scalar('PredictionLoss', prediction_loss, episode) agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) # Add final chart for this episode try: agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode) except Exception as e: logger.warning(f"Error adding final chart: {e}") logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, " f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, " f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, " f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%") # Save best model by reward if episode_reward > best_reward: best_reward = episode_reward agent.save("models/trading_agent_best_reward.pt") # Save best model by PnL if env.episode_pnl > best_pnl: best_pnl = env.episode_pnl agent.save("models/trading_agent_best_pnl.pt") # Save checkpoint if episode % 10 == 0: agent.save(f"models/trading_agent_episode_{episode}.pt") except Exception as e: logger.error(f"Error in episode 
{episode}: {e}") continue # Save final model agent.save("models/trading_agent_final.pt") # Plot training results plot_training_results(stats) return stats def plot_training_results(stats): """Plot detailed training results""" plt.figure(figsize=(20, 15)) # Plot rewards plt.subplot(3, 2, 1) plt.plot(stats['episode_rewards']) plt.title('Episode Rewards') plt.xlabel('Episode') plt.ylabel('Reward') # Plot balance plt.subplot(3, 2, 2) plt.plot(stats['balances']) plt.title('Account Balance') plt.xlabel('Episode') plt.ylabel('Balance ($)') # Plot win rate plt.subplot(3, 2, 3) plt.plot(stats['win_rates']) plt.title('Win Rate') plt.xlabel('Episode') plt.ylabel('Win Rate (%)') # Plot episode PnL plt.subplot(3, 2, 4) plt.plot(stats['episode_pnls']) plt.title('Episode PnL') plt.xlabel('Episode') plt.ylabel('PnL ($)') # Plot cumulative PnL plt.subplot(3, 2, 5) plt.plot(stats['cumulative_pnl']) plt.title('Cumulative PnL') plt.xlabel('Episode') plt.ylabel('Cumulative PnL ($)') # Plot drawdown plt.subplot(3, 2, 6) plt.plot(stats['drawdowns']) plt.title('Maximum Drawdown') plt.xlabel('Episode') plt.ylabel('Drawdown (%)') plt.tight_layout() plt.savefig('training_results.png') # Save statistics to CSV df = pd.DataFrame(stats) df.to_csv('training_stats.csv', index=False) logger.info("Training statistics saved to training_stats.csv and training_results.png") def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" total_reward = 0 total_profit = 0 total_trades = 0 winning_trades = 0 for episode in range(num_episodes): state = env.reset() episode_reward = 0 initial_balance = env.balance done = False while not done: # Select action (no exploration) action = agent.select_action(state, training=False) next_state, reward, done, info = env.step(action) state = next_state episode_reward += reward total_reward += episode_reward total_profit += env.balance - initial_balance # Count trades and wins for trade in env.trades: if 'pnl_percent' in trade: total_trades += 1 if trade['pnl_percent'] > 0: winning_trades += 1 # Calculate averages avg_reward = total_reward / num_episodes avg_profit = total_profit / num_episodes win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0 logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, " f"Win Rate={win_rate:.1f}%") return avg_reward, avg_profit, win_rate async def test_training(): """Test the training process with a small number of episodes""" logger.info("Starting training tests...") # Initialize exchange exchange = ccxt.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True, }) try: # Create environment with small initial balance for testing env = TradingEnvironment( exchange=exchange, symbol="ETH/USDT", timeframe="1m", leverage=MAX_LEVERAGE, initial_balance=100, # Small balance for testing demo=True # Always use demo mode for testing ) # Fetch initial data await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000) # Create agent agent = Agent(state_size=STATE_SIZE, action_size=env.action_space) # Run a few test episodes test_episodes = 3 logger.info(f"Running {test_episodes} test episodes...") for episode in range(test_episodes): state = env.reset() episode_reward = 0 done = False step = 0 while not done and step < 100: # Limit steps for testing # Select action action = agent.select_action(state) # Take action next_state, reward, done, info = env.step(action) # Store experience agent.memory.push(state, action, reward, next_state, done) # Learn loss = 
agent.learn() state = next_state episode_reward += reward step += 1 # Print progress if step % 10 == 0: logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}") logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}") # Test model saving try: agent.save("models/test_model.pt") logger.info("Successfully saved model") except Exception as e: logger.error(f"Error saving model: {e}") logger.info("Training tests completed successfully") return True except Exception as e: logger.error(f"Training test failed: {e}") return False finally: await exchange.close() async def initialize_exchange(): """Initialize the exchange connection""" try: # Try to initialize with async support first try: exchange = ccxt.pro.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True }) logger.info(f"Exchange initialized with async support: {exchange.id}") except (AttributeError, ImportError): # Fall back to standard CCXT exchange = ccxt.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True }) logger.info(f"Exchange initialized with standard CCXT: {exchange.id}") return exchange except Exception as e: logger.error(f"Failed to initialize exchange: {e}") raise async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """Fetch historical OHLCV data from the exchange""" try: logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}") # Use the refactored fetch method data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) if not data: logger.warning("No historical data received") return data except Exception as e: logger.error(f"Failed to fetch historical data: {e}") return [] async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50): """Run the trading bot in live mode with enhanced error handling""" logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe") logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}") # Verify agent is properly initialized try: # Ensure agent has all required attributes if not hasattr(agent, 'hidden_size'): agent.hidden_size = 256 # Default value logger.warning("Agent missing hidden_size attribute, using default: 256") if not hasattr(agent, 'lstm_layers'): agent.lstm_layers = 2 # Default value logger.warning("Agent missing lstm_layers attribute, using default: 2") if not hasattr(agent, 'attention_heads'): agent.attention_heads = 4 # Default value logger.warning("Agent missing attention_heads attribute, using default: 4") logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}") except Exception as e: logger.error(f"Error checking agent configuration: {e}") # Continue anyway, as these are just informational attributes if not demo: # Confirm with user before starting live trading confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. 
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
    """Run the trading bot in live mode with enhanced error handling"""
    logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
    logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")

    # Verify agent is properly initialized
    try:
        # Ensure agent has all required attributes
        if not hasattr(agent, 'hidden_size'):
            agent.hidden_size = 256  # Default value
            logger.warning("Agent missing hidden_size attribute, using default: 256")

        if not hasattr(agent, 'lstm_layers'):
            agent.lstm_layers = 2  # Default value
            logger.warning("Agent missing lstm_layers attribute, using default: 2")

        if not hasattr(agent, 'attention_heads'):
            agent.attention_heads = 4  # Default value
            logger.warning("Agent missing attention_heads attribute, using default: 4")

        logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
    except Exception as e:
        logger.error(f"Error checking agent configuration: {e}")
        # Continue anyway, as these are just informational attributes

    if not demo:
        # Confirm with user before starting live trading
        confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
        if confirmation != "CONFIRM":
            logger.info("Live trading canceled by user")
            return

        # Initialize futures trading if not in demo mode
        try:
            await env.initialize_futures(exchange)
            logger.info(f"Futures trading initialized with {leverage}x leverage")
        except Exception as e:
            logger.error(f"Failed to initialize futures trading: {str(e)}")
            logger.info("Falling back to demo mode for safety")
            demo = True

    # Initialize TensorBoard for monitoring
    if not hasattr(agent, 'writer') or agent.writer is None:
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')

    # Track performance metrics
    trades_count = 0
    winning_trades = 0
    total_profit = 0
    max_drawdown = 0
    peak_balance = env.balance
    step_counter = 0
    prev_position = 'flat'

    # Create directory for trade logs
    os.makedirs('trade_logs', exist_ok=True)
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    trade_log_path = f'trade_logs/trades_{current_time}.csv'
    with open(trade_log_path, 'w') as f:
        f.write("timestamp,action,price,position_size,balance,pnl\n")

    logger.info("Entering live trading loop...")

    try:
        while True:
            try:
                # Fetch latest candle data
                candle = await get_latest_candle(exchange, symbol)
                if candle is None:
                    logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
                    await asyncio.sleep(5)
                    continue

                # Add new data to environment
                env.add_data(candle)

                # Get current state and select action
                state = env.get_state()

                # Verify state shape matches agent's expected input
                if state.shape[0] != agent.state_size:
                    logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
                    # Pad or truncate state to match expected size
                    if state.shape[0] < agent.state_size:
                        state = np.pad(state, (0, agent.state_size - state.shape[0]))
                    else:
                        state = state[:agent.state_size]

                action = agent.select_action(state, training=False)

                # Ensure action is valid
                if action >= agent.action_size:
                    logger.warning(f"Invalid action {action}, clipping to {agent.action_size - 1}")
                    action = agent.action_size - 1

                # Log action
                action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
                logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")

                # Execute action
                if not demo:
                    # Execute real trade on exchange
                    current_price = env.data[-1]['close']
                    trade_result = await env.execute_real_trade(exchange, action, current_price)
                    if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
                        error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
                        logger.error(f"Trade execution failed: {error_msg}")
                        # Continue with simulated trade for tracking purposes

                # Update environment with action (simulated in demo mode).
                # Capture the result first so the step is never executed twice
                # when env.step returns 3 values instead of 4.
                step_result = env.step(action)
                if len(step_result) == 4:
                    next_state, reward, done, info = step_result
                else:
                    logger.warning("Step function returned 3 values instead of 4, creating info dict")
                    next_state, reward, done = step_result
                    info = {
                        'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
                        'price': env.current_price,
                        'balance': env.balance,
                        'position': env.position,
                        'pnl': env.total_pnl
                    }

                # Log trade if position changed
                if env.position != prev_position:
                    trades_count += 1
                    if env.last_trade_profit > 0:
                        winning_trades += 1
                    total_profit += env.last_trade_profit

                    # Log trade details
                    with open(trade_log_path, 'a') as f:
                        f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")

                    logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")

                # Update performance metrics
                if env.balance > peak_balance:
                    peak_balance = env.balance
                current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
                if current_drawdown > max_drawdown:
                    max_drawdown = current_drawdown

                # Update TensorBoard metrics
                step_counter += 1
                agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)

                # Update chart visualization
                if step_counter % 5 == 0 or env.position != prev_position:
                    agent.add_chart_to_tensorboard(env, step_counter)

                # Log performance summary
                if trades_count > 0:
                    win_rate = (winning_trades / trades_count) * 100
                    agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)

                    performance_text = f"""
                    **Live Trading Performance**
                    Balance: ${env.balance:.2f}
                    Total PnL: ${env.total_pnl:.2f}
                    Trades: {trades_count}
                    Win Rate: {win_rate:.1f}%
                    Max Drawdown: {max_drawdown*100:.1f}%
                    """
                    agent.writer.add_text('Performance', performance_text, step_counter)

                prev_position = env.position

                # Wait for next candle
                logger.info(f"Waiting for next candle... (Step {step_counter})")
                await asyncio.sleep(10)  # Check every 10 seconds

            except Exception as e:
                logger.error(f"Error in live trading loop: {str(e)}")
                logger.error(traceback.format_exc())
                logger.info("Continuing after error...")
                await asyncio.sleep(30)  # Wait longer after an error

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")

        # Final performance report
        if trades_count > 0:
            win_rate = (winning_trades / trades_count) * 100
            logger.info("Trading session summary:")
            logger.info(f"Total trades: {trades_count}")
            logger.info(f"Win rate: {win_rate:.1f}%")
            logger.info(f"Final balance: ${env.balance:.2f}")
            logger.info(f"Total profit: ${total_profit:.2f}")
            logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
            logger.info(f"Trade log saved to: {trade_log_path}")
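
# Pure-function sketch of the drawdown bookkeeping done inside the live
# trading loops (live_trading above and live_trading_with_websocket below);
# added for readability and easy unit testing, not wired into either loop.
def update_drawdown(balance, peak_balance, max_drawdown):
    """Return updated (peak_balance, current_drawdown, max_drawdown)."""
    peak_balance = max(peak_balance, balance)
    current_drawdown = (peak_balance - balance) / peak_balance if peak_balance > 0 else 0.0
    max_drawdown = max(max_drawdown, current_drawdown)
    return peak_balance, current_drawdown, max_drawdown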
async def get_latest_candle(exchange, symbol):
    """Get the latest candle data"""
    try:
        # Use the refactored fetch method with limit=1
        data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)

        if data and len(data) > 0:
            return data[0]
        else:
            logger.warning("No candle data received")
            return None
    except Exception as e:
        logger.error(f"Failed to fetch latest candle: {e}")
        return None

async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
    """Fetch OHLCV data with proper handling for both async and standard CCXT"""
    try:
        # Check if exchange has fetchOHLCV method
        if not hasattr(exchange, 'fetchOHLCV'):
            logger.error("Exchange does not support OHLCV data fetching")
            return []

        # Handle different CCXT versions
        if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False):
            # Use async method if available
            ohlcv = await exchange.fetchOHLCV(symbol, timeframe, limit=limit)
        else:
            # Use synchronous method with run_in_executor so the event loop
            # is not blocked by the network call
            loop = asyncio.get_running_loop()
            ohlcv = await loop.run_in_executor(
                None,
                lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
            )

        # Convert to list of dictionaries
        data = []
        for candle in ohlcv:
            timestamp, open_price, high, low, close, volume = candle
            data.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
        return data
    except Exception as e:
        logger.error(f"Failed to fetch OHLCV data: {e}")
        return []
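
# Table-driven sketch of the timeframe-to-seconds mapping that
# initialize_websocket_data_stream below implements with an if/elif chain;
# same values and the same 1m default.
TIMEFRAME_SECONDS = {'1m': 60, '5m': 300, '15m': 900, '1h': 3600}

def timeframe_to_seconds(timeframe, default=60):
    """Map a timeframe string like '5m' to its length in seconds."""
    return TIMEFRAME_SECONDS.get(timeframe, default)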
async def initialize_websocket_data_stream(symbol="ETH/USDT", timeframe="1m"):
    """Initialize a WebSocket connection for real-time trading data

    Args:
        symbol: Trading pair symbol (e.g., "ETH/USDT")
        timeframe: Timeframe for candle aggregation (e.g., "1m")

    Returns:
        Tuple of (websocket, initial_candles) where websocket is the
        BinanceWebSocket instance and initial_candles is a list of
        historical candle dicts
    """
    try:
        # Initialize historical data handler to get initial data
        historical_data = BinanceHistoricalData()

        # Convert timeframe to seconds for historical data
        if timeframe == "1m":
            interval_seconds = 60
        elif timeframe == "5m":
            interval_seconds = 300
        elif timeframe == "15m":
            interval_seconds = 900
        elif timeframe == "1h":
            interval_seconds = 3600
        else:
            interval_seconds = 60  # Default to 1m

        # Fetch initial historical data
        initial_data = historical_data.get_historical_candles(
            symbol=symbol,
            interval_seconds=interval_seconds,
            limit=1000  # Get 1000 candles for good history
        )

        # Convert pandas DataFrame to list of dictionaries for our environment
        initial_candles = []
        if not initial_data.empty:
            for _, row in initial_data.iterrows():
                candle = {
                    'timestamp': int(row['timestamp'].timestamp() * 1000),
                    'open': float(row['open']),
                    'high': float(row['high']),
                    'low': float(row['low']),
                    'close': float(row['close']),
                    'volume': float(row['volume'])
                }
                initial_candles.append(candle)
            logger.info(f"Loaded {len(initial_candles)} historical candles")
        else:
            logger.warning("No historical data fetched")

        # Initialize WebSocket for real-time data
        binance_ws = BinanceWebSocket(symbol.replace('/', ''))
        await binance_ws.connect()

        logger.info(f"WebSocket for {symbol} initialized successfully")
        return binance_ws, initial_candles
    except Exception as e:
        logger.error(f"Failed to initialize WebSocket data stream: {e}")
        logger.error(traceback.format_exc())
        return None, []
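
# Sketch of the per-tick candle update rule used inline by
# process_websocket_ticks below; factored out here for clarity only.
def update_candle(candle, price, volume):
    """Fold one tick into an in-progress OHLCV candle dict."""
    candle['high'] = max(candle['high'], price)
    candle['low'] = min(candle['low'], price)
    candle['close'] = price
    candle['volume'] += volume
    return candle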
async def process_websocket_ticks(websocket, env, agent=None, demo=True, timeframe="1m"):
    """Process real-time ticks from WebSocket and aggregate them into candles

    Args:
        websocket: BinanceWebSocket instance
        env: TradingEnvironment instance
        agent: Agent instance (optional, for live trading)
        demo: Whether to run in demo mode
        timeframe: Timeframe for candle aggregation
    """
    # Initialize variables for candle aggregation
    current_candle = None
    current_minute = None
    trades_count = 0
    step_counter = 0

    # For tracking sudden price movements
    last_prices = []
    price_movement_threshold = 0.5  # 0.5% movement threshold
    volume_spike_threshold = 2.0  # 2x average volume
    recent_volumes = []

    try:
        logger.info("Starting WebSocket tick processing...")

        while websocket.running:
            # Get the next tick from WebSocket
            tick = await websocket.receive()

            if tick is None:
                # No data received, wait and try again
                await asyncio.sleep(0.1)
                continue

            # Extract data from tick
            timestamp = tick.get('timestamp')
            price = tick.get('price')
            volume = tick.get('volume')

            if timestamp is None or price is None:
                logger.warning(f"Invalid tick data received: {tick}")
                continue

            # Track price movement for significant changes
            last_prices.append(price)
            if len(last_prices) > 20:
                last_prices.pop(0)

            # Track volumes for volume spikes
            recent_volumes.append(volume)
            if len(recent_volumes) > 20:
                recent_volumes.pop(0)

            # Check for significant price movement
            if len(last_prices) >= 5:
                price_change_pct = abs(price - last_prices[0]) / last_prices[0] * 100
                avg_volume = np.mean(recent_volumes[:-1]) if len(recent_volumes) > 1 else volume
                volume_ratio = volume / avg_volume if avg_volume > 0 else 1.0

                # Log significant movements
                if price_change_pct > price_movement_threshold:
                    logger.info(f"Significant price movement detected: {price_change_pct:.2f}% change")
                if volume_ratio > volume_spike_threshold:
                    logger.info(f"Volume spike detected: {volume_ratio:.2f}x average volume")

                # Force more frequent trading decisions on significant movements
                if (price_change_pct > price_movement_threshold or volume_ratio > volume_spike_threshold) and agent is not None and current_candle is not None:
                    # Create a temporary candle with current data
                    temp_candle = current_candle.copy()
                    temp_candle['close'] = price  # Update with latest price

                    # Add to environment. Note: env.add_data appends permanently,
                    # so this intermediate candle stays in env.data alongside the
                    # finalized candle for the same minute.
                    env.add_data(temp_candle)

                    # Get action; force exploitation (no exploration) during
                    # significant movements
                    state = env.get_state()
                    action = agent.select_action(state, training=False)

                    # Execute action in environment
                    next_state, reward, done, info = env.step(action)

                    # Log trading activity
                    action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
                    logger.info(f"Significant movement action: {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}")

            # Convert timestamp to datetime
            tick_time = datetime.datetime.fromtimestamp(timestamp / 1000)

            # For 1-minute candles, track the minute
            if timeframe == "1m":
                tick_minute = tick_time.replace(second=0, microsecond=0)

                # If this is a new minute, close the current candle and start a new one
                if current_minute is None or tick_minute > current_minute:
                    # If there was a previous candle, add it to the environment
                    if current_candle is not None:
                        env.add_data(current_candle)

                        # Process trading decisions if agent is provided
                        if agent is not None:
                            state = env.get_state()
                            action = agent.select_action(state, training=False)

                            # Execute action in environment
                            next_state, reward, done, info = env.step(action)

                            # Log trading activity
                            action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
                            logger.info(f"Step {step_counter}: Action {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}")
                            step_counter += 1

                    # Start a new candle
                    current_minute = tick_minute
                    current_candle = {
                        'timestamp': int(current_minute.timestamp() * 1000),
                        'open': price,
                        'high': price,
                        'low': price,
                        'close': price,
                        'volume': volume
                    }
                    logger.debug(f"Started new candle at {current_minute}")
                else:
                    # Update the current candle
                    current_candle['high'] = max(current_candle['high'], price)
                    current_candle['low'] = min(current_candle['low'], price)
                    current_candle['close'] = price
                    current_candle['volume'] += volume

            # For other timeframes, implement similar logic
            # ...

    except asyncio.CancelledError:
        logger.info("WebSocket processing canceled")
    except Exception as e:
        logger.error(f"Error in WebSocket tick processing: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Make sure to close the WebSocket
        if websocket:
            await websocket.close()
        logger.info("WebSocket connection closed")
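
# Sketch of the 'significant price movement' trigger used in
# process_websocket_ticks above: compares the newest tracked price against
# the oldest one, with the same 5-tick minimum and percentage threshold.
def is_significant_move(prices, threshold_pct=0.5):
    """Return True when the tracked price window moved more than threshold_pct percent."""
    if len(prices) < 5 or prices[0] == 0:
        return False
    return abs(prices[-1] - prices[0]) / prices[0] * 100 > threshold_pct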
def ensure_pytorch_compatibility():
    """Ensure compatibility with PyTorch 2.6+ for model loading"""
    try:
        from torch.serialization import add_safe_globals

        # Register the numpy scalar global for PyTorch 2.6+ weights-only
        # loading, covering both the numpy>=2 ('_core') and numpy<2 ('core')
        # module paths
        add_safe_globals(['numpy._core.multiarray.scalar'])
        add_safe_globals(['numpy.core.multiarray.scalar'])
        logger.info("Added numpy scalar to PyTorch safe globals")
    except Exception as e:
        logger.warning(f"Could not configure PyTorch compatibility: {e}")
        logger.warning("This might cause issues with model loading in PyTorch 2.6+")
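
# Example invocations (assuming this module is saved as main.py; the flag
# names match the argparse definition in main() below):
#   python main.py --mode train --episodes 500
#   python main.py --mode eval --model models/trading_agent_best_pnl.pt
#   python main.py --mode live --demo true --use-websocket --dashboard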
async def main():
    # Ensure PyTorch compatibility before any model loading
    ensure_pytorch_compatibility()

    parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
                        help='Operation mode: train, eval, or live')
    parser.add_argument('--episodes', type=int, default=1000,
                        help='Number of episodes for training or evaluation')
    parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
                        help='Run in demo mode (paper trading) if true')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m',
                        help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
    parser.add_argument('--leverage', type=int, default=50,
                        help='Leverage for futures trading')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model file for evaluation or live trading')
    parser.add_argument('--use-websocket', action='store_true',
                        help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')
    parser.add_argument('--dashboard', action='store_true',
                        help='Enable Dash dashboard visualization for real-time trading')
    args = parser.parse_args()

    # Convert string boolean to actual boolean
    demo_mode = args.demo.lower() == 'true'

    # Get device (GPU or CPU)
    device = get_device()

    exchange = None
    try:
        # Initialize exchange
        exchange = await initialize_exchange()

        # Create environment
        env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)

        if args.mode == 'train':
            # Fetch initial data for training
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            # Create agent with consistent parameters
            # Note: Using STATE_SIZE and action_size=4 for consistency
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            # Train the agent
            logger.info(f"Starting training for {args.episodes} episodes...")
            stats = await train_agent(agent, env, num_episodes=args.episodes)

        elif args.mode == 'eval' or args.mode == 'live':
            # Fetch initial data for the specified symbol and timeframe
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            # Determine model path
            model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return

            # Create agent with default parameters
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            # Try to load the model
            try:
                # Add numpy scalar to safe globals before loading
                from torch.serialization import add_safe_globals
                add_safe_globals(['numpy._core.multiarray.scalar'])

                # Load the model
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.error(f"Failed to load model: {e}")
                # Ask user if they want to continue with a new model
                if args.mode == 'live':
                    confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
                    if confirmation.lower() != 'y':
                        logger.info("Live trading canceled by user")
                        return
                    logger.info("Continuing with a new model")
                else:
                    logger.info("Continuing evaluation with a new model")

            if args.mode == 'eval':
                # Evaluate the agent
                logger.info("Evaluating agent...")
                avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)

            elif args.mode == 'live':
                # Start live trading
                logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
                logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")

                if args.use_websocket:
                    logger.info("Using Binance WebSocket for real-time data")
                    await live_trading_with_websocket(
                        agent=agent,
                        env=env,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage,
                        use_dashboard=args.dashboard
                    )
                else:
                    logger.info("Using CCXT for real-time data")
                    await live_trading(
                        agent=agent,
                        env=env,
                        exchange=exchange,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage
                    )

    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Clean up exchange connection
        if exchange:
            try:
                if hasattr(exchange, 'close'):
                    await exchange.close()
                elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
                    await exchange.client.close()
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.warning(f"Could not properly close exchange connection: {e}")
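
# Marker conventions used by create_candlestick_figure below:
#   '^' green  - buy (open long)       'v' red    - sell (open short)
#   'x' gold   - close long            'x' black  - close short
#   'X' purple - stop loss             '*' cyan   - take profit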
def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
    """Create a candlestick chart with trade signals for TensorBoard visualization"""
    if len(data) < 10:
        return None

    try:
        # Create figure
        fig = plt.figure(figsize=(12, 8))

        # Prepare data for plotting
        df = pd.DataFrame(data[-window_size:])
        df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('date', inplace=True)

        # Create subplot grid
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        price_ax = plt.subplot(gs[0])
        volume_ax = plt.subplot(gs[1], sharex=price_ax)

        # Plot candlesticks - fall back to a simple line plot if this fails.
        # Note: real candlesticks require `mpf` to be the mplfinance package
        # (import mplfinance as mpf); with plain pyplot this call raises and
        # the fallback below is used.
        try:
            mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
        except Exception as e:
            logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
            # Fallback to simple plot
            price_ax.plot(df.index, df['close'], label='Price')
            volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)

        # Add trade signals
        for signal in trade_signals:
            try:
                timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
                price = signal['price']

                if signal['type'] == 'buy':
                    price_ax.plot(timestamp, price, '^', color='green', markersize=10)
                elif signal['type'] == 'sell':
                    price_ax.plot(timestamp, price, 'v', color='red', markersize=10)
                elif signal['type'] == 'close_long':
                    price_ax.plot(timestamp, price, 'x', color='gold', markersize=10)
                elif signal['type'] == 'close_short':
                    price_ax.plot(timestamp, price, 'x', color='black', markersize=10)
                elif 'stop_loss' in signal['type']:
                    price_ax.plot(timestamp, price, 'X', color='purple', markersize=10)
                elif 'take_profit' in signal['type']:
                    price_ax.plot(timestamp, price, '*', color='cyan', markersize=10)
            except Exception as e:
                logger.warning(f"Error plotting signal: {e}")
                continue

        # Add balance and PnL annotation
        if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
            balance = trade_signals[-1]['balance']
            pnl = trade_signals[-1]['pnl']
            price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
                              xy=(0.02, 0.95), xycoords='axes fraction',
                              bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

        # Set title and format
        price_ax.set_title(title)
        fig.tight_layout()

        # Convert to image
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        plt.close(fig)

        img = Image.open(buf)
        return img
    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return None
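
# Optional sketch factoring out the figure-to-PIL conversion that
# create_candlestick_figure above performs inline via an in-memory PNG buffer.
def figure_to_image(fig):
    """Render a matplotlib figure to a PIL Image and close the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)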
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False):
    """Run the trading bot in live mode using Binance WebSocket for real-time data

    Args:
        agent: The trading agent to use for decision making
        env: The trading environment
        symbol: The trading pair symbol (e.g., "ETH/USDT")
        timeframe: The candlestick timeframe (e.g., "1m")
        demo: Whether to run in demo mode (paper trading)
        leverage: The leverage to use for trading
        use_dashboard: Whether to display the real-time dashboard

    Returns:
        None
    """
    logger.info(f"Starting live trading with WebSocket for {symbol} on {timeframe} timeframe")
    logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")

    # If not demo mode, confirm with user before starting live trading
    if not demo:
        confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
        if confirmation != "CONFIRM":
            logger.info("Live trading canceled by user")
            return

    # Initialize TensorBoard for monitoring
    if not hasattr(agent, 'writer') or agent.writer is None:
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')

    # Initialize Dash dashboard if enabled
    dashboard = None
    if use_dashboard:
        try:
            dashboard = TradingDashboard(symbol)
            dashboard_started = dashboard.start()  # Start the dashboard in a separate thread
            if dashboard_started:
                logger.info("Trading dashboard enabled at http://localhost:8060")
            else:
                logger.warning("Failed to start trading dashboard, continuing without visualization")
                dashboard = None
        except Exception as e:
            logger.error(f"Error initializing dashboard: {e}")
            logger.error(traceback.format_exc())
            dashboard = None

    # Track performance metrics
    trades_count = 0
    winning_trades = 0
    total_profit = 0
    max_drawdown = 0
    peak_balance = env.balance
    step_counter = 0

    # Create directory for trade logs
    os.makedirs('trade_logs', exist_ok=True)
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
    with open(trade_log_path, 'w') as f:
        f.write("timestamp,action,price,position_size,balance,pnl\n")

    try:
        # Initialize WebSocket connection and get historical data
        websocket, initial_candles = await initialize_websocket_data_stream(symbol, timeframe)
        if websocket is None or not initial_candles:
            logger.error("Failed to initialize WebSocket data stream")
            return

        # Load initial historical data into the environment
        logger.info(f"Loading {len(initial_candles)} initial candles into environment")
        for candle in initial_candles:
            env.add_data(candle)

        # Reset environment with historical data
        env.reset()

        # Update dashboard with initial data if enabled
        if dashboard:
            dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)

        # Initialize futures trading if not in demo mode
        exchange = None
        if not demo:
            # Initialize exchange for order execution
            exchange = await initialize_exchange()
            if exchange:
                try:
                    await env.initialize_futures(exchange)
                    logger.info(f"Futures trading initialized with {leverage}x leverage")
                except Exception as e:
                    logger.error(f"Failed to initialize futures trading: {str(e)}")
                    logger.info("Falling back to demo mode for safety")
                    demo = True

        # Start WebSocket processing in the background
        websocket_task = asyncio.create_task(
            process_websocket_ticks(websocket, env, agent, demo, timeframe)
        )

        # Main tracking loop
        prev_position = 'flat'
        while True:
            try:
                # Check if position has changed
                if env.position != prev_position:
                    trades_count += 1
                    if hasattr(env, 'last_trade_profit') and env.last_trade_profit > 0:
                        winning_trades += 1
                    if hasattr(env, 'last_trade_profit'):
                        total_profit += env.last_trade_profit

                    # Log trade details
                    current_time = datetime.datetime.now().isoformat()
                    last_action = getattr(env, 'last_action', 0)
                    action_name = "HOLD" if last_action == 0 else "BUY" if last_action == 1 else "SELL" if last_action == 2 else "CLOSE"
                    with open(trade_log_path, 'a') as f:
                        f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n")

                    logger.info(f"Trade executed: {action_name} at ${env.current_price:.2f}, PnL: ${getattr(env, 'last_trade_profit', 0):.2f}")

                # Update performance metrics
                if env.balance > peak_balance:
                    peak_balance = env.balance
                current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
                if current_drawdown > max_drawdown:
                    max_drawdown = current_drawdown

                # Update TensorBoard metrics
                step_counter += 1
                if step_counter % 10 == 0:  # Update every 10 steps
                    agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                    agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                    agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)

                # Update chart visualization
                if step_counter % 30 == 0 or env.position != prev_position:
                    agent.add_chart_to_tensorboard(env, step_counter)

                # Log performance summary
                if trades_count > 0:
                    win_rate = (winning_trades / trades_count) * 100
                    agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)

                    performance_text = f"""
                    **Live Trading Performance**
                    Balance: ${env.balance:.2f}
                    Total PnL: ${env.total_pnl:.2f}
                    Trades: {trades_count}
                    Win Rate: {win_rate:.1f}%
                    Max Drawdown: {max_drawdown*100:.1f}%
                    """
                    agent.writer.add_text('Performance', performance_text, step_counter)

                # Update the dashboard with latest data if enabled
                if dashboard:
                    dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)

                prev_position = env.position

                # Sleep for a short time to prevent CPU hogging
                await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"Error in live trading monitor loop: {str(e)}")
                logger.error(traceback.format_exc())
                await asyncio.sleep(10)  # Wait longer after an error

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")

        # Cancel the WebSocket task
        if 'websocket_task' in locals() and not websocket_task.done():
            websocket_task.cancel()
            try:
                await websocket_task
            except asyncio.CancelledError:
                pass

        # Close the exchange connection if it exists
        if exchange:
            await exchange.close()

        # Final performance report
        if trades_count > 0:
            win_rate = (winning_trades / trades_count) * 100
            logger.info("Trading session summary:")
            logger.info(f"Total trades: {trades_count}")
            logger.info(f"Win rate: {win_rate:.1f}%")
            logger.info(f"Final balance: ${env.balance:.2f}")
            logger.info(f"Total profit: ${total_profit:.2f}")

    except Exception as e:
        logger.error(f"Critical error in live trading: {str(e)}")
        logger.error(traceback.format_exc())
    finally:
        # Make sure to close WebSocket
        if 'websocket' in locals() and websocket:
            await websocket.close()

        # Close the exchange connection if it exists
        if 'exchange' in locals() and exchange:
            await exchange.close()
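
# Both live trading loops write trade logs with the header
# "timestamp,action,price,position_size,balance,pnl". Reading one back for
# offline analysis is a one-liner (sketch, assuming that header):
def load_trade_log(path):
    """Load a trade log CSV produced by the live trading loops above."""
    return pd.read_csv(path, parse_dates=['timestamp'])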
class TradingDashboard:
    """Dashboard for visualizing trading activity with Dash"""

    def __init__(self, symbol="ETH/USDT"):
        self.symbol = symbol
        self.env = None
        self.candles = []
        self.trade_signals = []

        # Create Dash app
        self.app = dash.Dash(__name__, suppress_callback_exceptions=True)

        # Create basic layout
        self.app.layout = html.Div([
            # Header
            html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}),

            # Main content
            html.Div([
                # Chart
                html.Div([
                    dcc.Graph(id='candlestick-chart', style={'height': '70vh'}),
                    dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0)
                ], style={'width': '70%', 'display': 'inline-block'}),

                # Trading info
                html.Div([
                    html.Div([
                        html.H3("Account Info"),
                        html.Div(id='account-info')
                    ]),
                    html.Div([
                        html.H3("Recent Trades"),
                        html.Div(id='recent-trades')
                    ])
                ], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'})
            ])
        ])

        # Setup callbacks
        self._setup_callbacks()

        # Thread for running the server
        self.thread = None
        self.is_running = False
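
    # Data flow note: update_data() is called from the trading loop thread and
    # stores the latest env/candles/signals on this instance; the Dash
    # callbacks below re-read those attributes on every Interval tick, so the
    # page stays current without a client-side data store.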
    def _setup_callbacks(self):
        @self.app.callback(
            Output('candlestick-chart', 'figure'),
            [Input('interval-component', 'n_intervals')]
        )
        def update_chart(n):
            # Read the latest data captured by update_data(); mutating
            # app.layout after startup does not propagate to connected
            # clients, so the callback pulls from the instance instead.
            candles = self.candles
            signals = self.trade_signals

            # Create figure with subplots
            fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                                vertical_spacing=0.1, row_heights=[0.7, 0.3])

            if candles:
                # Convert to dataframe
                df = pd.DataFrame(candles[-100:])  # Show last 100 candles
                df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')

                # Add candlestick trace
                fig.add_trace(
                    go.Candlestick(
                        x=df['timestamp'],
                        open=df['open'],
                        high=df['high'],
                        low=df['low'],
                        close=df['close'],
                        name='Price'
                    ),
                    row=1, col=1
                )

                # Add volume trace
                fig.add_trace(
                    go.Bar(
                        x=df['timestamp'],
                        y=df['volume'],
                        name='Volume'
                    ),
                    row=2, col=1
                )

                # Add trade signals
                for signal in signals:
                    if signal['timestamp'] >= df['timestamp'].iloc[0].timestamp() * 1000:
                        signal_time = pd.to_datetime(signal['timestamp'], unit='ms')
                        marker_color = 'green' if signal['type'] == 'buy' else 'red' if signal['type'] == 'sell' else 'orange'
                        marker_symbol = 'triangle-up' if signal['type'] == 'buy' else 'triangle-down' if signal['type'] == 'sell' else 'circle'

                        # Add marker for signal
                        fig.add_trace(
                            go.Scatter(
                                x=[signal_time],
                                y=[signal['price']],
                                mode='markers',
                                marker=dict(
                                    color=marker_color,
                                    size=12,
                                    symbol=marker_symbol
                                ),
                                name=signal['type'].capitalize(),
                                showlegend=False
                            ),
                            row=1, col=1
                        )

            # Update layout
            fig.update_layout(
                title=f'{self.symbol} Trading Chart',
                xaxis_rangeslider_visible=False,
                template='plotly_dark'
            )

            return fig

        @self.app.callback(
            [Output('account-info', 'children'),
             Output('recent-trades', 'children')],
            [Input('interval-component', 'n_intervals')]
        )
        def update_account_info(n):
            if not self.env:
                return "No data available", "No trades available"

            # Account info
            account_info = html.Div([
                html.P(f"Balance: ${self.env.balance:.2f}"),
                html.P(f"PnL: ${self.env.total_pnl:.2f}",
                       style={'color': 'green' if self.env.total_pnl > 0 else 'red' if self.env.total_pnl < 0 else 'white'}),
                html.P(f"Position: {self.env.position.upper()}")
            ])

            # Recent trades
            if hasattr(self.env, 'trades') and self.env.trades:
                # Get last 5 trades
                recent_trades = []
                for trade in reversed(self.env.trades[-5:]):
                    trade_card = html.Div([
                        html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"),
                        html.P(f"PnL: ${trade['pnl']:.2f}",
                               style={'color': 'green' if trade['pnl'] > 0 else 'red' if trade['pnl'] < 0 else 'white'})
                    ], style={'border': '1px solid #ddd', 'padding': '10px', 'margin-bottom': '5px'})
                    recent_trades.append(trade_card)
            else:
                recent_trades = [html.P("No trades yet")]

            return account_info, recent_trades

    def update_data(self, env=None, candles=None, trade_signals=None):
        """Update dashboard data; the Dash callbacks read these attributes."""
        if env:
            self.env = env
        if candles:
            self.candles = candles
        if trade_signals:
            self.trade_signals = trade_signals

    def start(self, host='localhost', port=8060):
        """Start the dashboard server in a separate thread"""
        if self.is_running:
            return False

        # Find an available port, trying the requested one first; a fresh
        # socket is used per attempt and always closed before starting Dash
        port_available = False
        for attempt_port in range(port, port + 10):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind((host, attempt_port))
                port_available = True
                port = attempt_port
                break
            except socket.error:
                logger.warning(f"Port {attempt_port} is already in use")
            finally:
                sock.close()

        if not port_available:
            logger.error("Could not find an available port for dashboard")
            return False

        # Create and start the thread
        self.thread = Thread(target=self._run_server, args=(host, port))
        self.thread.daemon = True  # Ensures the thread exits when the main program does
        self.thread.start()
        self.is_running = True
        logger.info(f"Trading dashboard started at http://{host}:{port}")

        # Verify the thread actually started
        if not self.thread.is_alive():
            logger.error("Dashboard thread failed to start")
            return False

        # Wait a short time to let the server initialize
        time.sleep(1.0)
        return True

    def _run_server(self, host, port):
        """Run the Dash server"""
        try:
            logger.info(f"Starting Dash server on {host}:{port}")
            # Note: Dash 3.x removed run_server; use self.app.run(...) there
            self.app.run_server(debug=False, host=host, port=port,
                                use_reloader=False, threaded=True)
        except Exception as e:
            logger.error(f"Error running dashboard server: {e}")
            self.is_running = False
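
# Usage sketch for the dashboard (normally driven by live_trading_with_websocket
# when --dashboard is passed):
#   dashboard = TradingDashboard("ETH/USDT")
#   if dashboard.start():  # serves http://localhost:8060 by default
#       dashboard.update_data(env=env, candles=env.data,
#                             trade_signals=env.trade_signals)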
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")