import os import time import json import numpy as np import pandas as pd from datetime import datetime import random import logging import asyncio import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from collections import deque, namedtuple from dotenv import load_dotenv import ccxt import websockets from torch.utils.tensorboard import SummaryWriter import torch.cuda.amp as amp # Add this import at the top from sklearn.preprocessing import MinMaxScaler import copy import argparse import traceback import io import matplotlib.dates as mdates from matplotlib.figure import Figure from PIL import Image import matplotlib.pyplot as mpf import matplotlib.gridspec as gridspec import datetime from realtime import BinanceWebSocket, BinanceHistoricalData from datetime import datetime as dt # Add Dash-related imports import dash from dash import html, dcc, callback_context from dash.dependencies import Input, Output, State import plotly.graph_objects as go from plotly.subplots import make_subplots from threading import Thread # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()] ) logger = logging.getLogger("trading_bot") # Load environment variables load_dotenv() MEXC_API_KEY = os.getenv('MEXC_API_KEY') MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY') # Constants INITIAL_BALANCE = 100 # USD MAX_LEVERAGE = 100 STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5% MEMORY_SIZE = 100000 BATCH_SIZE = 64 GAMMA = 0.99 # Discount factor EPSILON_START = 1.0 EPSILON_END = 0.05 EPSILON_DECAY = 10000 STATE_SIZE = 64 # Size of our state representation LEARNING_RATE = 1e-4 TARGET_UPDATE = 10 # Update target network every 10 episodes # Experience replay tuple Experience = namedtuple('Experience', ['state', 'action', 
'reward', 'next_state', 'done']) # Add this function near the top of the file, after the imports but before any classes def find_local_extrema(prices, window=5): """Find local minima (bottoms) and maxima (tops) in price data""" bottoms = [] tops = [] if len(prices) < window * 2 + 1: return bottoms, tops for i in range(window, len(prices) - window): # Check if this is a local minimum (bottom) if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \ all(prices[i] <= prices[i+j] for j in range(1, window+1)): bottoms.append(i) # Check if this is a local maximum (top) if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \ all(prices[i] >= prices[i+j] for j in range(1, window+1)): tops.append(i) return bottoms, tops class ReplayMemory: def __init__(self, capacity): self.memory = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.memory.append(Experience(state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.memory, batch_size) def __len__(self): return len(self.memory) class DQN(nn.Module): def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(DQN, self).__init__() self.state_size = state_size self.hidden_size = hidden_size self.lstm_layers = lstm_layers # Initial feature extraction self.fc1 = nn.Linear(state_size, hidden_size) # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes self.ln1 = nn.LayerNorm(hidden_size) self.dropout1 = nn.Dropout(0.2) # LSTM layer for sequential data self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2) # Attention mechanism self.attention = nn.MultiheadAttention(hidden_size, attention_heads) # Output layers with increased capacity self.fc2 = nn.Linear(hidden_size, hidden_size) self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm self.dropout2 = nn.Dropout(0.2) self.fc3 = nn.Linear(hidden_size, hidden_size 
// 2) # Dueling DQN architecture self.value_stream = nn.Linear(hidden_size // 2, 1) self.advantage_stream = nn.Linear(hidden_size // 2, action_size) # Transformer encoder for more complex pattern recognition encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) def forward(self, x): batch_size = x.size(0) if x.dim() > 1 else 1 # Ensure input has correct shape if x.dim() == 1: x = x.unsqueeze(0) # Add batch dimension # Check if state size matches expected input size if x.size(1) != self.state_size: # Handle mismatched input by either truncating or padding if x.size(1) > self.state_size: x = x[:, :self.state_size] # Truncate else: # Pad with zeros padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device) x = torch.cat([x, padding], dim=1) # Initial feature extraction x = self.fc1(x) x = F.relu(self.ln1(x)) # LayerNorm works with any batch size x = self.dropout1(x) # Reshape for LSTM x_lstm = x.unsqueeze(1) if x.dim() == 2 else x # Process through LSTM lstm_out, _ = self.lstm(x_lstm) lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1] # Process through transformer for more complex patterns transformer_input = x.unsqueeze(1) if x.dim() == 2 else x transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1)) transformer_out = transformer_out.transpose(0, 1).mean(dim=1) # Combine LSTM and transformer outputs x = lstm_out + transformer_out # Final layers x = self.fc2(x) x = F.relu(self.ln2(x)) # LayerNorm works with any batch size x = self.dropout2(x) x = F.relu(self.fc3(x)) # Dueling architecture value = self.value_stream(x) advantages = self.advantage_stream(x) qvals = value + (advantages - advantages.mean(dim=1, keepdim=True)) return qvals class PricePredictionModel(nn.Module): def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2): super(PricePredictionModel, 
self).__init__() self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2) self.fc = nn.Linear(hidden_size, output_size) self.scaler = MinMaxScaler(feature_range=(0, 1)) self.is_fitted = False def forward(self, x): # x shape: [batch_size, seq_len, 1] lstm_out, _ = self.lstm(x) # Use the last time step output predictions = self.fc(lstm_out[:, -1, :]) return predictions def preprocess(self, data): # Reshape data for scaler data_reshaped = np.array(data).reshape(-1, 1) # Fit scaler if not already fitted if not self.is_fitted: self.scaler.fit(data_reshaped) self.is_fitted = True # Transform data scaled_data = self.scaler.transform(data_reshaped) return scaled_data def postprocess(self, scaled_predictions): # Inverse transform to get actual price values return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten() def predict_next_candles(self, price_history, num_candles=5): if len(price_history) < 30: # Need enough history return np.zeros(num_candles) # Preprocess data scaled_data = self.preprocess(price_history) # Create sequence sequence = scaled_data[-30:].reshape(1, 30, 1) sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device) # Get predictions with torch.no_grad(): scaled_predictions = self(sequence_tensor).cpu().numpy()[0] # Postprocess predictions predictions = self.postprocess(scaled_predictions) return predictions def train_on_new_data(self, price_history, optimizer, epochs=10): if len(price_history) < 35: # Need enough history for training return 0.0 # Preprocess data scaled_data = self.preprocess(price_history) # Create sequences and targets sequences = [] targets = [] for i in range(len(scaled_data) - 35): # Sequence: 30 time steps seq = scaled_data[i:i+30] # Target: next 5 time steps target = scaled_data[i+30:i+35].flatten() sequences.append(seq) targets.append(target) if not sequences: # If no sequences were created return 0.0 # Convert to tensors sequences_tensor = 
torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device) targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device) # Training loop total_loss = 0 for _ in range(epochs): # Forward pass predictions = self(sequences_tensor) # Calculate loss loss = F.mse_loss(predictions, targets_tensor) # Backward pass and optimize optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() return total_loss / epochs class TradingEnvironment: def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True): """Initialize the trading environment""" self.initial_balance = initial_balance self.balance = initial_balance self.window_size = window_size self.demo = demo self.data = [] self.position = 'flat' # 'flat', 'long', or 'short' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 self.trades = [] self.win_count = 0 self.loss_count = 0 self.total_pnl = 0.0 self.episode_pnl = 0.0 self.peak_balance = initial_balance self.max_drawdown = 0.0 self.current_step = 0 self.current_price = 0 # For tracking signals for visualization self.trade_signals = [] # Initialize features self.features = { 'price': [], 'volume': [], 'rsi': [], 'macd': [], 'macd_signal': [], 'macd_hist': [], 'bollinger_upper': [], 'bollinger_mid': [], 'bollinger_lower': [], 'stoch_k': [], 'stoch_d': [], 'ema_9': [], 'ema_21': [], 'atr': [] } # Initialize price predictor self.price_predictor = None self.predicted_prices = np.array([]) # Initialize optimal trade tracking self.optimal_bottoms = [] self.optimal_tops = [] self.optimal_signals = np.array([]) # Add these new attributes self.leverage = MAX_LEVERAGE self.futures_symbol = "ETH_USDT" # Example futures symbol self.position_mode = "hedge" # For simultaneous long/short positions self.margin_mode = "cross" # Cross margin mode def reset(self): """Reset the environment to initial state""" self.balance = 
self.initial_balance self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 self.trades = [] self.win_count = 0 self.loss_count = 0 self.episode_pnl = 0.0 self.peak_balance = self.initial_balance self.max_drawdown = 0.0 self.current_step = 0 # Keep data but reset current position if len(self.data) > self.window_size: self.current_step = self.window_size self.current_price = self.data[self.current_step]['close'] # Reset trade signals self.trade_signals = [] return self.get_state() def add_data(self, candle): """Add a new candle to the data""" self.data.append(candle) self._update_features() self.current_price = candle['close'] def _initialize_features(self): """Initialize technical indicators and features""" if len(self.data) < 30: return # Convert data to pandas DataFrame for easier calculation df = pd.DataFrame(self.data) # Basic price and volume self.features['price'] = df['close'].values self.features['volume'] = df['volume'].values # Calculate RSI (14 periods) delta = df['close'].diff() gain = delta.where(delta > 0, 0).rolling(window=14).mean() loss = -delta.where(delta < 0, 0).rolling(window=14).mean() rs = gain / loss self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values # Calculate MACD ema12 = df['close'].ewm(span=12, adjust=False).mean() ema26 = df['close'].ewm(span=26, adjust=False).mean() macd = ema12 - ema26 signal = macd.ewm(span=9, adjust=False).mean() self.features['macd'] = macd.values self.features['macd_signal'] = signal.values self.features['macd_hist'] = (macd - signal).values # Calculate Bollinger Bands sma20 = df['close'].rolling(window=20).mean() std20 = df['close'].rolling(window=20).std() self.features['bollinger_upper'] = (sma20 + 2 * std20).values self.features['bollinger_mid'] = sma20.values self.features['bollinger_lower'] = (sma20 - 2 * std20).values # Calculate Stochastic Oscillator low_14 = df['low'].rolling(window=14).min() high_14 = 
df['high'].rolling(window=14).max() k = 100 * ((df['close'] - low_14) / (high_14 - low_14)) self.features['stoch_k'] = k.values self.features['stoch_d'] = k.rolling(window=3).mean().values # Calculate EMAs self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values # Calculate ATR high_low = df['high'] - df['low'] high_close = (df['high'] - df['close'].shift()).abs() low_close = (df['low'] - df['close'].shift()).abs() tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1) self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values def _update_features(self): """Update technical indicators with new data""" self._initialize_features() # Recalculate all features async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """Fetch initial historical data for the environment""" try: logger.info(f"Fetching initial data for {symbol}") # Use the refactored fetch method data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) # Update environment with fetched data if data: self.data = data self._initialize_features() logger.info(f"Initialized environment with {len(data)} candles") else: logger.warning("No initial data received") return len(data) > 0 except Exception as e: logger.error(f"Error fetching initial data: {e}") return False def step(self, action): """Take an action in the environment and return the next state, reward, and done flag""" # Check if we have enough data if self.current_step >= len(self.data) - 1: # We've reached the end of data done = True next_state = self.get_state() info = { 'action': 'none', 'price': self.current_price, 'balance': self.balance, 'position': self.position, 'pnl': self.total_pnl } return next_state, 0, done, info # Store current price before taking action self.current_price = self.data[self.current_step]['close'] # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: 
CLOSE) reward = self.calculate_reward(action) # Record trade signal for visualization if action > 0: # If not HOLD signal_type = None if action == 1: # BUY/LONG signal_type = 'buy' elif action == 2: # SELL/SHORT signal_type = 'sell' elif action == 3: # CLOSE if self.position == 'long': signal_type = 'close_long' elif self.position == 'short': signal_type = 'close_short' if signal_type: self.trade_signals.append({ 'timestamp': self.data[self.current_step]['timestamp'], 'price': self.current_price, 'type': signal_type, 'balance': self.balance, 'pnl': self.total_pnl }) # Check for stop loss / take profit hits self.check_sl_tp() # Move to next step self.current_step += 1 done = self.current_step >= len(self.data) - 1 # Get new state next_state = self.get_state() # Create info dictionary info = { 'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close', 'price': self.current_price, 'balance': self.balance, 'position': self.position, 'pnl': self.total_pnl } return next_state, reward, done, info def check_sl_tp(self): """Check if stop loss or take profit has been hit""" if self.position == 'flat': return if self.position == 'long': # Check stop loss if self.current_price <= self.stop_loss: # Stop loss hit pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance drawdown = (self.peak_balance - self.balance) / self.peak_balance self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': self.current_step - self.entry_index, 'market_direction': 
self.get_market_direction(), 'reason': 'stop_loss' }) # Update win/loss count self.loss_count += 1 logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Record signal for visualization self.trade_signals.append({ 'timestamp': self.data[self.current_step]['timestamp'], 'price': self.stop_loss, 'type': 'stop_loss_long', 'balance': self.balance, 'pnl': self.total_pnl }) # Reset position self.position = 'flat' self.entry_price = 0 self.entry_index = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 # Check take profit elif self.current_price >= self.take_profit: # Take profit hit pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance # Record trade self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': self.current_step - self.entry_index, 'market_direction': self.get_market_direction(), 'reason': 'take_profit' }) # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Record signal for visualization self.trade_signals.append({ 'timestamp': self.data[self.current_step]['timestamp'], 'price': self.take_profit, 'type': 'take_profit_long', 'balance': self.balance, 'pnl': self.total_pnl }) # Reset position self.position = 'flat' self.entry_price = 0 self.entry_index = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': # Check stop loss if self.current_price >= self.stop_loss: # Stop loss hit pnl_percent = (self.entry_price - 
self.stop_loss) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance drawdown = (self.peak_balance - self.balance) / self.peak_balance self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': self.current_step - self.entry_index, 'market_direction': self.get_market_direction(), 'reason': 'stop_loss' }) # Update win/loss count self.loss_count += 1 logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Record signal for visualization self.trade_signals.append({ 'timestamp': self.data[self.current_step]['timestamp'], 'price': self.stop_loss, 'type': 'stop_loss_short', 'balance': self.balance, 'pnl': self.total_pnl }) # Reset position self.position = 'flat' self.entry_price = 0 self.entry_index = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 # Check take profit elif self.current_price <= self.take_profit: # Take profit hit pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance # Record trade self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': self.current_step - self.entry_index, 'market_direction': 
self.get_market_direction(), 'reason': 'take_profit' }) # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Record signal for visualization self.trade_signals.append({ 'timestamp': self.data[self.current_step]['timestamp'], 'price': self.take_profit, 'type': 'take_profit_short', 'balance': self.balance, 'pnl': self.total_pnl }) # Reset position self.position = 'flat' self.entry_price = 0 self.entry_index = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 def get_state(self): """Create state representation for the agent with enhanced features""" # Ensure we have enough data if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0: # Return zeros if not enough data return np.zeros(STATE_SIZE) # Create a normalized state vector with recent price action and indicators state_components = [] # Safely get the latest price try: latest_price = self.features['price'][-1] except IndexError: # If we can't get the latest price, return zeros return np.zeros(STATE_SIZE) # Safely get price features try: # Price features (normalize recent prices by the latest price) price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0 state_components.append(price_features) except (IndexError, ZeroDivisionError): # If we can't get price features, use zeros state_components.append(np.zeros(10)) # Safely get volume features try: # Volume features (normalize by max volume) max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1 vol_features = np.array(self.features['volume'][-5:]) / max_vol state_components.append(vol_features) except (IndexError, ZeroDivisionError): # If we can't get volume features, use zeros state_components.append(np.zeros(5)) # Technical indicators rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1 state_components.append(rsi) # MACD (normalize) 
macd_vals = np.array(self.features['macd'][-3:]) macd_signal = np.array(self.features['macd_signal'][-3:]) macd_hist = np.array(self.features['macd_hist'][-3:]) macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5) macd_norm = macd_vals / macd_scale macd_signal_norm = macd_signal / macd_scale macd_hist_norm = macd_hist / macd_scale state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm]) # Bollinger position (where is price relative to bands) bb_upper = np.array(self.features['bollinger_upper'][-3:]) bb_lower = np.array(self.features['bollinger_lower'][-3:]) bb_mid = np.array(self.features['bollinger_mid'][-3:]) price = np.array(self.features['price'][-3:]) # Calculate position of price within Bollinger Bands (0 to 1) bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] state_components.append(np.array(bb_pos)) # Stochastic oscillator state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) # Add predicted prices (if available) if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Normalize predictions relative to current price pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0 state_components.append(pred_norm) else: # Add zeros if no predictions state_components.append(np.zeros(3)) # Add extrema signals (if available) if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: # Get recent signals idx = len(self.optimal_signals) - 5 if idx < 0: idx = 0 recent_signals = self.optimal_signals[idx:idx+5] # Pad if needed if len(recent_signals) < 5: recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') state_components.append(recent_signals) else: # Add zeros if no signals state_components.append(np.zeros(5)) # Position info position_info = np.zeros(5) if self.position == 'long': position_info[0] = 1.0 # Position is long 
position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss % position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance elif self.position == 'short': position_info[0] = -1.0 # Position is short position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss % position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance state_components.append(position_info) # NEW FEATURES START HERE # 1. Price momentum features (rate of change over different periods) if len(self.features['price']) >= 20: roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0 momentum_features = np.array([roc_5, roc_10, roc_20]) state_components.append(momentum_features) else: state_components.append(np.zeros(3)) # 2. 
Volatility features if len(self.features['price']) >= 20: # Calculate price returns returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] # Calculate volatility (standard deviation of returns) volatility = np.std(returns) # Calculate normalized high-low range high_low_range = np.mean([ (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] for i in range(max(0, len(self.data)-5), len(self.data)) ]) if len(self.data) > 0 else 0 # ATR normalized by price atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 volatility_features = np.array([volatility, high_low_range, atr_norm]) state_components.append(volatility_features) else: state_components.append(np.zeros(3)) # 3. Market regime features if len(self.features['price']) >= 50: # Trend strength (ADX-like measure) ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price trend_strength = abs(ema9 - ema21) / ema21 # Detect if in range or trending is_range_bound = 1.0 if self.is_uncertain_market() else 0.0 is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0 # Detect if near support/resistance near_support = 1.0 if self.is_near_support() else 0.0 near_resistance = 1.0 if self.is_near_resistance() else 0.0 market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance]) state_components.append(market_regime) else: state_components.append(np.zeros(5)) # 4. 
Trade history features if len(self.trades) > 0: # Recent win/loss ratio recent_trades = self.trades[-min(10, len(self.trades)):] win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades) # Average profit/loss avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0 avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0 # Normalize by balance avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0 avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0 # Last trade result last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0 trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]) state_components.append(trade_history) else: state_components.append(np.zeros(4)) # Combine all features state = np.concatenate([comp.flatten() for comp in state_components]) # Replace any NaN or infinite values state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) # Ensure the state has the correct size if len(state) != STATE_SIZE: logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}") # Pad or truncate to match expected size if len(state) < STATE_SIZE: state = np.pad(state, (0, STATE_SIZE - len(state))) else: state = state[:STATE_SIZE] return state def get_expanded_state_size(self): """Calculate the size of the expanded state representation""" # Create a dummy state to get its size state = self.get_state() return len(state) async def expand_model_with_new_features(agent, env): """Expand the model to handle new features without retraining from scratch""" # Get the new state size new_state_size = env.get_expanded_state_size() # Only expand if the new state size is larger if new_state_size > agent.state_size: 
logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})") # Expand the model success = agent.expand_model( new_state_size=new_state_size, new_hidden_size=512, # Increase hidden size for more capacity new_lstm_layers=3, # More layers for deeper patterns new_attention_heads=8 # More attention heads for complex relationships ) if success: logger.info(f"Model successfully expanded to handle {new_state_size} features") return True else: logger.error("Failed to expand model") return False else: logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient") return True def calculate_reward(self, action): """Calculate reward for the given action with improved penalties for losing trades""" reward = 0 # Base reward for actions if action == 0: # HOLD reward = -0.01 # Small penalty for doing nothing elif action == 1: # BUY/LONG if self.position == 'flat': # Opening a long position self.position = 'long' self.entry_price = self.current_price self.position_size = self.calculate_position_size() self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100) self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100) # Check if this is an optimal buy point (bottom) current_idx = len(self.features['price']) - 1 if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms: reward += 2.0 # Bonus for buying at a bottom else: # Check if we're buying in a downtrend (bad) if self.is_downtrend(): reward -= 0.5 # Penalty for buying in downtrend else: reward += 0.1 # Small reward for opening a position logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") elif self.position == 'short': # Close short and open long pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += 
pnl_dollar self.total_pnl += pnl_dollar # Record trade trade_duration = len(self.features['price']) - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': trade_duration, 'market_direction': self.get_market_direction() }) # Reward based on PnL with stronger penalties for losses if pnl_dollar > 0: reward += 1.0 + pnl_dollar / 10 # Positive reward for profit self.win_count += 1 else: # Stronger penalty for losses, scaled by the size of the loss loss_penalty = 1.0 + abs(pnl_dollar) / 5 reward -= loss_penalty self.loss_count += 1 # Extra penalty for closing a losing trade too quickly if trade_duration < 5: reward -= 0.5 # Penalty for very short losing trades logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Now open long self.position = 'long' self.entry_price = self.current_price self.entry_index = len(self.features['price']) - 1 self.position_size = self.calculate_position_size() self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100) self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100) # Check if this is an optimal buy point if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms: reward += 2.0 # Bonus for buying at a bottom logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") elif action == 2: # SELL/SHORT if self.position == 'flat': # Opening a short position self.position = 'short' self.entry_price = self.current_price self.position_size = self.calculate_position_size() self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100) self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100) # Check if this is an optimal sell point (top) current_idx = len(self.features['price']) - 1 if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: reward += 2.0 # Bonus for 
selling at a top else: reward += 0.1 # Small reward for opening a position logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") elif self.position == 'long': # Close long and open short pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar # Record trade self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar }) # Reward based on PnL if pnl_dollar > 0: reward += 1.0 + pnl_dollar / 10 # Positive reward for profit self.win_count += 1 else: reward -= 1.0 # Negative reward for loss self.loss_count += 1 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Now open short self.position = 'short' self.entry_price = self.current_price self.position_size = self.calculate_position_size() self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100) self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100) # Check if this is an optimal sell point current_idx = len(self.features['price']) - 1 if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: reward += 2.0 # Bonus for selling at a top logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") elif action == 3: # CLOSE if self.position == 'long': # Close long position pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: 
self.peak_balance = self.balance drawdown = (self.peak_balance - self.balance) / self.peak_balance self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar }) # Reward based on PnL if pnl_dollar > 0: reward += 1.0 + pnl_dollar / 10 # Positive reward for profit self.win_count += 1 else: reward -= 1.0 # Negative reward for loss self.loss_count += 1 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.entry_price = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': # Close short position pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar self.episode_pnl += pnl_dollar # Update max drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance drawdown = (self.peak_balance - self.balance) / self.peak_balance self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar }) # Reward based on PnL if pnl_dollar > 0: reward += 1.0 + pnl_dollar / 10 # Positive reward for profit self.win_count += 1 else: reward -= 1.0 # Negative reward for loss self.loss_count += 1 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.entry_price = 0 self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 # Add prediction accuracy component to reward if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Compare 
the first prediction with actual price if len(self.data) > 1: actual_price = self.data[-1]['close'] predicted_price = self.predicted_prices[0] prediction_error = abs(predicted_price - actual_price) / actual_price # Reward accurate predictions, penalize bad ones if prediction_error < 0.005: # Less than 0.5% error reward += 0.5 elif prediction_error > 0.02: # More than 2% error reward -= 0.5 return reward def is_downtrend(self): """Check if the market is in a downtrend""" if len(self.features['price']) < 20: return False # Use EMA to determine trend short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] # Downtrend if short EMA is below long EMA return short_ema < long_ema def is_uptrend(self): """Check if the market is in an uptrend""" if len(self.features['price']) < 20: return False # Use EMA to determine trend short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] # Uptrend if short EMA is above long EMA return short_ema > long_ema def get_market_direction(self): """Get the current market direction""" if self.is_uptrend(): return "uptrend" elif self.is_downtrend(): return "downtrend" else: return "sideways" def analyze_trades(self): """Analyze completed trades to identify patterns""" if not self.trades: return {} analysis = { 'total_trades': len(self.trades), 'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0), 'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0), 'avg_win': 0, 'avg_loss': 0, 'avg_duration': 0, 'uptrend_win_rate': 0, 'downtrend_win_rate': 0, 'sideways_win_rate': 0 } # Calculate averages wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0] losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0] durations = [t.get('duration', 0) for t in self.trades] analysis['avg_win'] = sum(wins) / len(wins) if wins else 0 analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0 analysis['avg_duration'] 
= sum(durations) / len(durations) if durations else 0 # Calculate win rates by market direction for direction in ['uptrend', 'downtrend', 'sideways']: direction_trades = [t for t in self.trades if t.get('market_direction') == direction] if direction_trades: wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0) analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100 return analysis def initialize_price_predictor(self, device="cpu"): """Initialize the price prediction model""" self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5) self.price_predictor.to(device) self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3) self.predicted_prices = np.array([]) def train_price_predictor(self): """Train the price prediction model on recent data""" if len(self.features['price']) < 35: return 0.0 # Get price history price_history = self.features['price'] # Train the model loss = self.price_predictor.train_on_new_data( price_history, self.price_predictor_optimizer, epochs=5 ) return loss def update_price_predictions(self): """Update price predictions""" if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None: self.predicted_prices = np.array([]) return # Get price history price_history = self.features['price'] try: # Get predictions self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5) except Exception as e: logger.warning(f"Error updating predictions: {e}") self.predicted_prices = np.array([]) def identify_optimal_trades(self): """Identify optimal entry and exit points based on local extrema""" if len(self.features['price']) < 20: return # Find local bottoms and tops bottoms, tops = find_local_extrema(self.features['price'], window=5) # Store optimal trade points self.optimal_bottoms = bottoms # Buy points self.optimal_tops = tops # Sell points # Create optimal 
trade signals self.optimal_signals = np.zeros(len(self.features['price'])) for i in bottoms: if 0 <= i < len(self.optimal_signals): # Ensure index is valid self.optimal_signals[i] = 1 # Buy signal for i in tops: if 0 <= i < len(self.optimal_signals): # Ensure index is valid self.optimal_signals[i] = -1 # Sell signal logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points") def calculate_position_size(self): """Calculate position size based on current balance and risk parameters""" # Use a fixed percentage of balance for each trade risk_percent = 5.0 # Risk 5% of balance per trade # Calculate position size with leverage position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE # Apply a safety factor to avoid liquidation safety_factor = 0.8 position_size *= safety_factor # Ensure minimum position size min_position = 10.0 # Minimum position size in USD position_size = max(position_size, min(min_position, self.balance * 0.5)) # Ensure position size doesn't exceed balance * leverage max_position = self.balance * MAX_LEVERAGE position_size = min(position_size, max_position) return position_size def calculate_fees(self, position_size): """Calculate trading fees for a given position size""" # Typical fee rate for crypto exchanges (0.1%) fee_rate = 0.001 # Calculate fee fee = position_size * fee_rate return fee def is_uncertain_market(self): """Check if the market is in an uncertain/sideways state""" if len(self.features['price']) < 20: return True # Check if price is within a narrow range recent_prices = self.features['price'][-20:] price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices) # Check if EMAs are close to each other if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: short_ema = self.features['ema_9'][-1] long_ema = self.features['ema_21'][-1] ema_diff = abs(short_ema - long_ema) / long_ema # Return True if price range is small and EMAs are close return price_range < 
0.02 and ema_diff < 0.005 return price_range < 0.015 # Very narrow range def is_near_support(self): """Check if current price is near a support level""" if not hasattr(self, 'features') or len(self.features['price']) < 30: return False # Find recent lows prices = self.features['price'][-30:] lows = [] for i in range(1, len(prices)-1): if prices[i] < prices[i-1] and prices[i] < prices[i+1]: lows.append(prices[i]) if not lows: return False # Check if current price is near any of these lows current_price = self.current_price for low in lows: if abs(current_price - low) / low < 0.01: # Within 1% of a recent low return True return False def is_near_resistance(self): """Check if current price is near a resistance level""" if not hasattr(self, 'features') or len(self.features['price']) < 30: return False # Find recent highs prices = self.features['price'][-30:] highs = [] for i in range(1, len(prices)-1): if prices[i] > prices[i-1] and prices[i] > prices[i+1]: highs.append(prices[i]) if not highs: return False # Check if current price is near any of these highs current_price = self.current_price for high in highs: if abs(current_price - high) / high < 0.01: # Within 1% of a recent high return True return False def is_market_turning(self): """Check if the market is potentially changing direction""" if len(self.features['price']) < 20: return False # Check for divergence between price and momentum indicators if len(self.features['rsi']) > 5: # Price making higher highs but RSI making lower highs (bearish divergence) price_trend = self.features['price'][-1] > self.features['price'][-5] rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5] if price_trend != rsi_trend: return True # Check for EMA crossover if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1: short_ema_prev = self.features['ema_9'][-2] long_ema_prev = self.features['ema_21'][-2] short_ema_curr = self.features['ema_9'][-1] long_ema_curr = self.features['ema_21'][-1] # Check if EMAs 
just crossed if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \ (short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr): return True return False def is_market_against_position(self, position_type): """Check if market conditions have turned against the current position""" if position_type == 'long': # For long positions, check if market has turned bearish return self.is_downtrend() and not self.is_near_support() elif position_type == 'short': # For short positions, check if market has turned bullish return self.is_uptrend() and not self.is_near_resistance() return False def is_near_optimal_exit(self, position_type): """Check if current price is near an optimal exit point for the position""" current_idx = len(self.features['price']) - 1 if position_type == 'long' and hasattr(self, 'optimal_tops'): # For long positions, optimal exit is near tops for top_idx in self.optimal_tops: if abs(current_idx - top_idx) < 3: # Within 3 candles of a top return True elif position_type == 'short' and hasattr(self, 'optimal_bottoms'): # For short positions, optimal exit is near bottoms for bottom_idx in self.optimal_bottoms: if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom return True return False def calculate_future_profit_potential(self, position_type, lookahead=20): """ Calculate potential profit if position is held for a certain period This is used for retrospective backtesting rewards Args: position_type: 'long' or 'short' lookahead: Number of candles to look ahead Returns: Potential profit percentage """ if len(self.data) <= 1 or self.current_step >= len(self.data): return 0 # Get current price current_price = self.current_price # Get future prices (if available in historical data) future_prices = [] current_idx = self.current_step # Safely get future prices for i in range(1, min(lookahead + 1, len(self.data) - current_idx)): if current_idx + i < len(self.data): future_prices.append(self.data[current_idx + i]['close']) 
if not future_prices: return 0 # Calculate potential profit if position_type == 'long': # For long positions, find the maximum price in the future max_future_price = max(future_prices) potential_profit = (max_future_price - current_price) / current_price * 100 else: # short # For short positions, find the minimum price in the future min_future_price = min(future_prices) potential_profit = (current_price - min_future_price) / current_price * 100 return potential_profit async def initialize_futures(self, exchange): """Initialize futures trading parameters""" if not self.demo: try: # Set up futures trading parameters await exchange.set_position_mode(True) # Hedge mode await exchange.set_margin_mode("cross", symbol=self.futures_symbol) await exchange.set_leverage(self.leverage, symbol=self.futures_symbol) logger.info(f"Futures initialized with {self.leverage}x leverage") except Exception as e: logger.error(f"Failed to initialize futures trading: {str(e)}") logger.info("Falling back to demo mode for safety") demo = True async def execute_real_trade(self, exchange, action, current_price): """Execute real futures trade on MEXC""" try: position_size = self.calculate_position_size() if action == 1: # Open long order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='buy', amount=position_size, params={'positionSide': 'LONG'} ) logger.info(f"Opened LONG position: {order}") elif action == 2: # Open short order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='sell', amount=position_size, params={'positionSide': 'SHORT'} ) logger.info(f"Opened SHORT position: {order}") elif action == 3: # Close position position_side = 'LONG' if self.position == 'long' else 'SHORT' order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='sell' if position_side == 'LONG' else 'buy', amount=self.position_size, params={'positionSide': position_side} ) logger.info(f"Closed {position_side} position: {order}") 
return order except Exception as e: logger.error(f"Trade execution failed: {e}") return None # Ensure GPU usage if available def get_device(): """Get the best available device (CUDA GPU or CPU)""" if torch.cuda.is_available(): device = torch.device("cuda") logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}") # Set up for mixed precision training torch.backends.cudnn.benchmark = True else: device = torch.device("cpu") logger.info("GPU not available, using CPU") return device # Update Agent class to use GPU properly class Agent: def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None): """Initialize Agent with architecture parameters stored as attributes""" self.state_size = state_size self.action_size = action_size self.hidden_size = hidden_size # Store hidden_size as an instance attribute self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute self.attention_heads = attention_heads # Store attention_heads as an instance attribute # Set device self.device = device if device is not None else get_device() # Initialize networks self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) # Initialize optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Initialize replay memory self.memory = ReplayMemory(MEMORY_SIZE) # Initialize exploration parameters self.epsilon = EPSILON_START self.epsilon_decay = EPSILON_DECAY self.epsilon_min = EPSILON_END # Initialize step counter self.steps_done = 0 # Initialize TensorBoard writer self.writer = None # Initialize GradScaler for mixed precision training self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None # Rest of the initialization code... 
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" logger.info(f"Expanding model: {self.state_size} → {new_state_size}, " f"hidden: {self.policy_net.hidden_size} → {new_hidden_size}") # Save old weights old_state_dict = self.policy_net.state_dict() # Create new larger networks new_policy_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) new_target_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) # Transfer weights for common layers new_state_dict = new_policy_net.state_dict() for name, param in old_state_dict.items(): if name in new_state_dict: # If shapes match, copy directly if new_state_dict[name].shape == param.shape: new_state_dict[name] = param # For first layer, copy weights for the original input dimensions elif name == "fc1.weight": new_state_dict[name][:, :self.state_size] = param # For other layers, initialize with a strategy that preserves scale else: logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}") # Load transferred weights new_policy_net.load_state_dict(new_state_dict) new_target_net.load_state_dict(new_state_dict) # Replace networks self.policy_net = new_policy_net self.target_net = new_target_net self.target_net.eval() # Update optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Update state size self.state_size = new_state_size # Print new model size total_params = sum(p.numel() for p in self.policy_net.parameters()) logger.info(f"New model size: {total_params:,} parameters") return True def select_action(self, state, training=True): sample = random.random() if training: # Epsilon decay self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \ np.exp(-1. 
* self.steps_done / EPSILON_DECAY) self.steps_done += 1 if sample > self.epsilon or not training: with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) action_values = self.policy_net(state_tensor) return action_values.max(1)[1].item() else: return random.randrange(self.action_size) def learn(self): """Learn from a batch of experiences""" if len(self.memory) < BATCH_SIZE: return None try: # Sample a batch of experiences experiences = self.memory.sample(BATCH_SIZE) # Convert experiences to tensors states = torch.FloatTensor([e.state for e in experiences]).to(self.device) actions = torch.LongTensor([e.action for e in experiences]).to(self.device) rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device) next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device) dones = torch.FloatTensor([e.done for e in experiences]).to(self.device) # Use mixed precision for forward/backward passes if self.device.type == "cuda" and self.scaler is not None: with torch.amp.autocast('cuda'): # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with target network with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass with mixed precision self.optimizer.zero_grad() self.scaler.scale(loss).backward() # Gradient clipping to prevent exploding gradients self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.scaler.step(self.optimizer) self.scaler.update() else: # Standard precision for CPU # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with target network with 
torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping to prevent exploding gradients torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() # Update steps done self.steps_done += 1 # Update target network if self.steps_done % TARGET_UPDATE == 0: self.target_net.load_state_dict(self.policy_net.state_dict()) return loss.item() except Exception as e: logger.error(f"Error during learning: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path="models/trading_agent_best_pnl.pt"): """Save the model in a format compatible with PyTorch 2.6+""" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(path), exist_ok=True) # Ensure architecture parameters are set if not hasattr(self, 'hidden_size'): self.hidden_size = 256 # Default value logger.warning("Setting default hidden_size=256 for saving") if not hasattr(self, 'lstm_layers'): self.lstm_layers = 2 # Default value logger.warning("Setting default lstm_layers=2 for saving") if not hasattr(self, 'attention_heads'): self.attention_heads = 4 # Default value logger.warning("Setting default attention_heads=4 for saving") # Save model state checkpoint = { 'policy_net': self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(), 'optimizer': self.optimizer.state_dict(), 'epsilon': self.epsilon, 'state_size': self.state_size, 'action_size': self.action_size, 'hidden_size': self.hidden_size, 'lstm_layers': self.lstm_layers, 'attention_heads': self.attention_heads } # Save scaler state if it exists if 
hasattr(self, 'scaler') and self.scaler is not None: checkpoint['scaler'] = self.scaler.state_dict() # Save with pickle_protocol=4 for better compatibility torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4) logger.info(f"Model saved to {path}") except Exception as e: logger.error(f"Error saving model: {e}") import traceback logger.error(traceback.format_exc()) def load(self, path="models/trading_agent_best_pnl.pt"): """Load a trained model with improved error handling for PyTorch 2.6 compatibility""" try: # First try to load with weights_only=False (for models saved with older PyTorch versions) try: logger.info(f"Attempting to load model with weights_only=False: {path}") checkpoint = torch.load(path, map_location=self.device, weights_only=False) logger.info("Model loaded successfully with weights_only=False") except Exception as e1: logger.warning(f"Failed to load with weights_only=False: {e1}") # Try with safe_globals context manager try: logger.info("Attempting to load with safe_globals context manager") import numpy as np from torch.serialization import safe_globals # Add numpy scalar to safe globals with safe_globals(['numpy._core.multiarray.scalar']): checkpoint = torch.load(path, map_location=self.device) logger.info("Model loaded successfully with safe_globals") except Exception as e2: logger.warning(f"Failed to load with safe_globals: {e2}") # Last resort: try with pickle_module=pickle logger.info("Attempting to load with pickle_module") import pickle checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False) logger.info("Model loaded successfully with pickle_module") # Load state dictionaries self.policy_net.load_state_dict(checkpoint['policy_net']) self.target_net.load_state_dict(checkpoint['target_net']) # Try to load optimizer state try: self.optimizer.load_state_dict(checkpoint['optimizer']) except Exception as e: logger.warning(f"Could not load optimizer state: {e}") # Load 
epsilon if available if 'epsilon' in checkpoint: self.epsilon = checkpoint['epsilon'] # Load architecture parameters if available if 'state_size' in checkpoint: self.state_size = checkpoint['state_size'] if 'action_size' in checkpoint: self.action_size = checkpoint['action_size'] if 'hidden_size' in checkpoint: self.hidden_size = checkpoint['hidden_size'] else: # If hidden_size not in checkpoint, infer from model try: self.hidden_size = self.policy_net.fc1.weight.shape[0] logger.info(f"Inferred hidden_size={self.hidden_size} from model") except: self.hidden_size = 256 # Default value logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}") if 'lstm_layers' in checkpoint: self.lstm_layers = checkpoint['lstm_layers'] else: self.lstm_layers = 2 # Default value if 'attention_heads' in checkpoint: self.attention_heads = checkpoint['attention_heads'] else: self.attention_heads = 4 # Default value logger.info(f"Model loaded successfully from {path}") except Exception as e: logger.error(f"Error loading model: {e}") import traceback logger.error(traceback.format_exc()) raise def add_chart_to_tensorboard(self, env, global_step): """Add trading chart to TensorBoard""" try: if len(env.data) < 10: return # Create chart image chart_img = create_candlestick_figure( env.data, env.trade_signals, window_size=100, title=f"Trading Chart - Step {global_step}" ) if chart_img is not None: # Convert PIL image to numpy array for TensorBoard chart_array = np.array(chart_img) # TensorBoard expects [C, H, W] format chart_array = np.transpose(chart_array, (2, 0, 1)) self.writer.add_image('Trading Chart', chart_array, global_step) # Add position information as text entry_price = env.entry_price if env.entry_price else 0.00 position_info = f""" **Current Position**: {env.position.upper()} **Entry Price**: ${entry_price:.2f} **Current Price**: ${env.data[-1]['close']:.2f} **Position Size**: ${env.position_size:.2f} **Unrealized PnL**: ${env.total_pnl:.2f} """ 
self.writer.add_text('Position', position_info, global_step) except Exception as e: logger.error(f"Error adding chart to TensorBoard: {str(e)}") # Continue without visualization rather than crashing async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" # Connect to MEXC websocket uri = "wss://stream.mexc.com/ws" async with websockets.connect(uri) as websocket: # Subscribe to kline data subscribe_msg = { "method": "SUBSCRIPTION", "params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"] } await websocket.send(json.dumps(subscribe_msg)) logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines") while True: try: response = await websocket.recv() data = json.loads(response) if 'data' in data: kline = data['data'] candle = { 'timestamp': kline['t'], 'open': float(kline['o']), 'high': float(kline['h']), 'low': float(kline['l']), 'close': float(kline['c']), 'volume': float(kline['v']) } yield candle except Exception as e: logger.error(f"Websocket error: {e}") # Try to reconnect await asyncio.sleep(5) break async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000): """Train the agent using historical and live data with GPU acceleration""" # Initialize statistics tracking stats = { 'episode_rewards': [], 'episode_lengths': [], 'balances': [], 'win_rates': [], 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], 'prediction_accuracy': [], 'trade_analysis': [] } # Track best models best_reward = float('-inf') best_pnl = float('-inf') # Initialize TensorBoard writer if not already initialized if not hasattr(agent, 'writer') or agent.writer is None: agent.writer = SummaryWriter('runs/training') # Training loop for episode in range(num_episodes): try: # Reset environment state = env.reset() episode_reward = 0 prediction_loss = 0 # Episode loop for step in range(max_steps_per_episode): # Select action action = agent.select_action(state) # Take 
action try: next_state, reward, done, info = env.step(action) except Exception as e: logger.error(f"Error in step function: {e}") break # Store transition in replay memory agent.memory.push(state, action, reward, next_state, done) # Move to the next state state = next_state # Update episode reward episode_reward += reward # Learn from experience if len(agent.memory) > BATCH_SIZE: agent.learn() # Update price predictions periodically if step % 50 == 0: try: env.update_price_predictions() env.identify_optimal_trades() except Exception as e: logger.warning(f"Error updating predictions: {e}") # Add chart to TensorBoard periodically if step % 50 == 0 or (step == max_steps_per_episode - 1) or done: try: global_step = episode * max_steps_per_episode + step agent.add_chart_to_tensorboard(env, global_step) except Exception as e: logger.warning(f"Error adding chart to TensorBoard: {e}") # End episode if done if done: break # Update target network periodically if episode % TARGET_UPDATE == 0: agent.update_target_network() # Calculate win rate total_trades = env.win_count + env.loss_count win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 # Train price predictor try: if episode % 5 == 0 and len(env.data) > 50: prediction_loss = env.train_price_predictor() except Exception as e: logger.warning(f"Error training price predictor: {e}") prediction_loss = 0 # Analyze trades try: trade_analysis = env.analyze_trades() stats['trade_analysis'].append(trade_analysis) except Exception as e: logger.warning(f"Error analyzing trades: {e}") trade_analysis = {} stats['trade_analysis'].append({}) # Calculate prediction accuracy prediction_accuracy = 0.0 try: if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: if len(env.data) > 5: actual_prices = [candle['close'] for candle in env.data[-5:]] predicted = env.predicted_prices[:min(5, len(actual_prices))] errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])] 
prediction_accuracy = 100 * (1 - sum(errors) / len(errors)) except Exception as e: logger.warning(f"Error calculating prediction accuracy: {e}") # Log statistics stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) stats['balances'].append(env.balance) stats['win_rates'].append(win_rate) stats['episode_pnls'].append(env.episode_pnl) stats['cumulative_pnl'].append(env.total_pnl) stats['drawdowns'].append(env.max_drawdown * 100) stats['prediction_accuracy'].append(prediction_accuracy) # Log detailed trade analysis if trade_analysis: logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, " f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | " f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}") # Log to TensorBoard agent.writer.add_scalar('Reward/train', episode_reward, episode) agent.writer.add_scalar('Balance/train', env.balance, episode) agent.writer.add_scalar('WinRate/train', win_rate, episode) agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode) agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode) agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) agent.writer.add_scalar('PredictionLoss', prediction_loss, episode) agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) # Add final chart for this episode try: agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode) except Exception as e: logger.warning(f"Error adding final chart: {e}") logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, " f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, " f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, " f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%") # Save best model by reward if episode_reward > best_reward: best_reward = episode_reward 
agent.save("models/trading_agent_best_reward.pt") # Save best model by PnL if env.episode_pnl > best_pnl: best_pnl = env.episode_pnl agent.save("models/trading_agent_best_pnl.pt") # Save checkpoint if episode % 10 == 0: agent.save(f"models/trading_agent_episode_{episode}.pt") except Exception as e: logger.error(f"Error in episode {episode}: {e}") continue # Save final model agent.save("models/trading_agent_final.pt") # Plot training results plot_training_results(stats) return stats
# NOTE(review): the line above is the whitespace-collapsed tail of train_agent, whose
# start lies outside this chunk; it is intentionally left byte-identical.


def plot_training_results(stats):
    """Plot per-episode training metrics and persist them to disk.

    Writes a 3x2 grid of line plots to ``training_results.png`` and every
    column of ``stats`` to ``training_stats.csv``.

    Args:
        stats: Mapping of metric name -> per-episode list. The keys
            'episode_rewards', 'balances', 'win_rates', 'episode_pnls',
            'cumulative_pnl' and 'drawdowns' are plotted; all keys go to CSV.
    """
    panels = [
        ('episode_rewards', 'Episode Rewards', 'Reward'),
        ('balances', 'Account Balance', 'Balance ($)'),
        ('win_rates', 'Win Rate', 'Win Rate (%)'),
        ('episode_pnls', 'Episode PnL', 'PnL ($)'),
        ('cumulative_pnl', 'Cumulative PnL', 'Cumulative PnL ($)'),
        ('drawdowns', 'Maximum Drawdown', 'Drawdown (%)'),
    ]

    fig = plt.figure(figsize=(20, 15))
    try:
        for position, (key, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(3, 2, position)
            plt.plot(stats[key])
            plt.title(title)
            plt.xlabel('Episode')
            plt.ylabel(ylabel)
        plt.tight_layout()
        plt.savefig('training_results.png')
    finally:
        # Release the figure; otherwise every call leaves an open figure behind.
        plt.close(fig)

    # Columns can have different lengths when an episode fails part-way through
    # (e.g. 'trade_analysis' is appended before the other metrics), which makes
    # pd.DataFrame(stats) raise. Wrapping each column in a Series pads with NaN.
    df = pd.DataFrame({key: pd.Series(values) for key, values in stats.items()})
    df.to_csv('training_stats.csv', index=False)
    logger.info("Training statistics saved to training_stats.csv and training_results.png")


def evaluate_agent(agent, env, num_episodes=10):
    """Evaluate the agent greedily (no exploration) on ``env``.

    Args:
        agent: Agent exposing ``select_action(state, training=False)``.
        env: Environment exposing ``reset``, ``step``, ``balance`` and ``trades``.
        num_episodes: Number of episodes to run; values <= 0 yield zero averages.

    Returns:
        Tuple ``(avg_reward, avg_profit, win_rate)`` where ``win_rate`` is the
        percentage of trades carrying a positive ``pnl_percent``.
    """
    total_reward = 0
    total_profit = 0
    total_trades = 0
    winning_trades = 0

    for _ in range(num_episodes):
        state = env.reset()
        initial_balance = env.balance
        episode_reward = 0
        done = False

        while not done:
            action = agent.select_action(state, training=False)
            state, reward, done, _info = env.step(action)
            episode_reward += reward

        total_reward += episode_reward
        total_profit += env.balance - initial_balance

        # Only trades that carry a realised pnl_percent count as completed.
        closed_trades = [trade for trade in env.trades if 'pnl_percent' in trade]
        total_trades += len(closed_trades)
        winning_trades += sum(1 for trade in closed_trades if trade['pnl_percent'] > 0)

    avg_reward = total_reward / num_episodes if num_episodes > 0 else 0
    avg_profit = total_profit / num_episodes if num_episodes > 0 else 0
    win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0

    logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
                f"Win Rate={win_rate:.1f}%")

    return avg_reward, avg_profit, win_rate


async def _close_exchange(exchange):
    """Close a CCXT exchange client if possible, never raising.

    Works for both async clients (``close()`` returns an awaitable) and sync
    clients (``close()`` returns a plain value, or no ``close`` exists). Falls
    back to ``exchange.client.close()`` when the exchange has no ``close``.
    """
    import inspect

    if not exchange:
        return
    close = getattr(exchange, 'close', None)
    if close is None:
        close = getattr(getattr(exchange, 'client', None), 'close', None)
    if close is None:
        return
    try:
        result = close()
        # Only async clients hand back something that must be awaited.
        if inspect.isawaitable(result):
            await result
        logger.info("Exchange connection closed")
    except Exception as e:
        logger.warning(f"Could not properly close exchange connection: {e}")


async def test_training():
    """Smoke-test the training loop with a few short episodes in demo mode.

    Returns:
        True when all test episodes ran and no exception escaped, else False.
    """
    logger.info("Starting training tests...")

    exchange = ccxt.mexc({
        'apiKey': MEXC_API_KEY,
        'secret': MEXC_SECRET_KEY,
        'enableRateLimit': True,
    })

    try:
        # Small balance and demo mode: this must never touch real funds.
        env = TradingEnvironment(
            exchange=exchange,
            symbol="ETH/USDT",
            timeframe="1m",
            leverage=MAX_LEVERAGE,
            initial_balance=100,
            demo=True,
        )

        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)

        agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)

        test_episodes = 3
        logger.info(f"Running {test_episodes} test episodes...")

        for episode in range(test_episodes):
            state = env.reset()
            episode_reward = 0
            done = False
            step = 0

            while not done and step < 100:  # Limit steps for testing
                action = agent.select_action(state)
                next_state, reward, done, info = env.step(action)
                agent.memory.push(state, action, reward, next_state, done)

                # Same guard as train_agent: only learn once a full batch exists.
                if len(agent.memory) > BATCH_SIZE:
                    agent.learn()

                state = next_state
                episode_reward += reward
                step += 1

                if step % 10 == 0:
                    logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")

            logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")

        try:
            os.makedirs('models', exist_ok=True)
            agent.save("models/test_model.pt")
            logger.info("Successfully saved model")
        except Exception as e:
            logger.error(f"Error saving model: {e}")

        logger.info("Training tests completed successfully")
        return True

    except Exception as e:
        logger.error(f"Training test failed: {e}")
        return False

    finally:
        # ccxt.mexc here is the synchronous client; awaiting its close() blindly
        # raised inside finally and replaced the return value.
        await _close_exchange(exchange)


async def initialize_exchange():
    """Create the MEXC exchange client, preferring the async ``ccxt.pro`` API.

    Falls back to the synchronous ``ccxt.mexc`` client when ``ccxt.pro`` is not
    available as an attribute of the imported ``ccxt`` package.

    Raises:
        Exception: Re-raises any error from constructing the client.
    """
    try:
        try:
            exchange = ccxt.pro.mexc({
                'apiKey': MEXC_API_KEY,
                'secret': MEXC_SECRET_KEY,
                'enableRateLimit': True
            })
            logger.info(f"Exchange initialized with async support: {exchange.id}")
        except (AttributeError, ImportError):
            exchange = ccxt.mexc({
                'apiKey': MEXC_API_KEY,
                'secret': MEXC_SECRET_KEY,
                'enableRateLimit': True
            })
            logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
        return exchange
    except Exception as e:
        logger.error(f"Failed to initialize exchange: {e}")
        raise


async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
    """Fetch up to ``limit`` historical OHLCV candles as a list of dicts.

    Returns an empty list (and logs) on failure.
    """
    try:
        logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")
        data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
        if not data:
            logger.warning("No historical data received")
        return data
    except Exception as e:
        logger.error(f"Failed to fetch historical data: {e}")
        return []


def _action_name(action):
    """Return the display name of a discrete action: 0 HOLD, 1 BUY, 2 SELL, else CLOSE.

    Uses ``==`` comparisons (not a dict lookup) so that numpy/torch scalar
    actions compare by value.
    """
    if action == 0:
        return "HOLD"
    if action == 1:
        return "BUY"
    if action == 2:
        return "SELL"
    return "CLOSE"


def _step_environment(env, action):
    """Step ``env`` exactly once and always return ``(next_state, reward, done, info)``.

    Some environment versions return only ``(next_state, reward, done)``; for
    those an ``info`` dict is built from the environment's attributes. Any other
    tuple length raises ``ValueError`` from the unpacking, as before.
    """
    result = env.step(action)
    if len(result) == 3:
        logger.warning("Step function returned 3 values instead of 4, creating info dict")
        next_state, reward, done = result
        info = {
            'action': _action_name(action).lower(),
            'price': env.current_price,
            'balance': env.balance,
            'position': env.position,
            'pnl': env.total_pnl,
        }
        return next_state, reward, done, info
    next_state, reward, done, info = result
    return next_state, reward, done, info


def _performance_markdown(env, trades_count, win_rate, max_drawdown):
    """Build the markdown summary written to TensorBoard during live trading."""
    return (
        "**Live Trading Performance**\n\n"
        f"Balance: ${env.balance:.2f}\n"
        f"Total PnL: ${env.total_pnl:.2f}\n"
        f"Trades: {trades_count}\n"
        f"Win Rate: {win_rate:.1f}%\n"
        f"Max Drawdown: {max_drawdown*100:.1f}%\n"
    )


def _log_session_summary(env, trades_count, winning_trades, total_profit,
                         max_drawdown=None, trade_log_path=None):
    """Log the end-of-session report; the trade statistics only when trades happened."""
    if trades_count > 0:
        win_rate = (winning_trades / trades_count) * 100
        logger.info("Trading session summary:")
        logger.info(f"Total trades: {trades_count}")
        logger.info(f"Win rate: {win_rate:.1f}%")
        logger.info(f"Final balance: ${env.balance:.2f}")
        logger.info(f"Total profit: ${total_profit:.2f}")
        if max_drawdown is not None:
            logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
    if trade_log_path is not None:
        logger.info(f"Trade log saved to: {trade_log_path}")


async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
    """Run the trading bot on candles polled from a CCXT exchange.

    Every 10 seconds the latest ``timeframe`` candle is fetched and passed to
    ``env.add_data``. The agent then picks an action greedily
    (``training=False``) and the environment is stepped once. In non-demo mode
    the user must type ``CONFIRM``; the action is then also sent to the exchange
    via ``env.execute_real_trade``. Position changes are appended to
    ``trade_logs/trades_<time>.csv`` and metrics go to TensorBoard. A session
    summary is logged however the loop ends.

    Args:
        agent: Agent exposing ``select_action``, ``state_size``, ``action_size``,
            ``writer`` and ``add_chart_to_tensorboard``.
        env: Trading environment fed with the live candles.
        exchange: CCXT exchange used for market data (and orders when live).
        symbol: Trading pair, e.g. ``"ETH/USDT"``.
        timeframe: Candle timeframe to poll, e.g. ``"1m"``.
        demo: Paper-trade only when True.
        leverage: Leverage reported when futures trading is initialised.
            NOTE(review): only logged here, not passed to
            ``env.initialize_futures`` — confirm that is intended.
    """
    logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
    logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")

    # Backfill purely informational attributes that older agent objects may lack.
    try:
        for attr, default in (('hidden_size', 256), ('lstm_layers', 2), ('attention_heads', 4)):
            if not hasattr(agent, attr):
                setattr(agent, attr, default)
                logger.warning(f"Agent missing {attr} attribute, using default: {default}")
        logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
    except Exception as e:
        logger.error(f"Error checking agent configuration: {e}")
        # Continue anyway, as these are just informational attributes

    if not demo:
        confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
        if confirmation != "CONFIRM":
            logger.info("Live trading canceled by user")
            return

        try:
            await env.initialize_futures(exchange)
            logger.info(f"Futures trading initialized with {leverage}x leverage")
        except Exception as e:
            logger.error(f"Failed to initialize futures trading: {str(e)}")
            logger.info("Falling back to demo mode for safety")
            demo = True

    if not hasattr(agent, 'writer') or agent.writer is None:
        # `datetime` is the module in this file; `dt` is the datetime class.
        current_time = dt.now().strftime("%Y%m%d_%H%M%S")
        agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')

    trades_count = 0
    winning_trades = 0
    total_profit = 0
    max_drawdown = 0
    peak_balance = env.balance
    step_counter = 0
    prev_position = 'flat'

    os.makedirs('trade_logs', exist_ok=True)
    current_time = dt.now().strftime("%Y%m%d_%H%M%S")
    trade_log_path = f'trade_logs/trades_{current_time}.csv'
    with open(trade_log_path, 'w') as f:
        f.write("timestamp,action,price,position_size,balance,pnl\n")

    logger.info("Entering live trading loop...")

    try:
        while True:
            try:
                candle = await get_latest_candle(exchange, symbol, timeframe)
                if candle is None:
                    logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
                    await asyncio.sleep(5)
                    continue

                # NOTE(review): with a 10 s poll the same (possibly still-forming)
                # candle is passed to env.add_data several times per bar; confirm
                # env.add_data de-duplicates by timestamp.
                env.add_data(candle)

                state = env.get_state()
                if state.shape[0] != agent.state_size:
                    logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
                    # Pad with zeros or truncate so the network input shape matches.
                    if state.shape[0] < agent.state_size:
                        state = np.pad(state, (0, agent.state_size - state.shape[0]))
                    else:
                        state = state[:agent.state_size]

                action = agent.select_action(state, training=False)
                if action >= agent.action_size:
                    logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
                    action = agent.action_size - 1

                action_name = _action_name(action)
                logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")

                if not demo:
                    current_price = env.data[-1]['close']
                    trade_result = await env.execute_real_trade(exchange, action, current_price)
                    if not isinstance(trade_result, dict) or not trade_result.get('success', False):
                        error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
                        logger.error(f"Trade execution failed: {error_msg}")
                        # Continue with simulated trade for tracking purposes

                # Step exactly once: re-calling env.step after a failed unpack
                # used to advance the environment twice.
                next_state, reward, done, info = _step_environment(env, action)

                if env.position != prev_position:
                    trades_count += 1
                    if env.last_trade_profit > 0:
                        winning_trades += 1
                    total_profit += env.last_trade_profit

                    trade_action = info.get('action', action_name.lower())
                    close_price = env.data[-1]['close']
                    with open(trade_log_path, 'a') as f:
                        f.write(f"{dt.now().isoformat()},{trade_action},{close_price},{env.position_size},{env.balance},{env.last_trade_profit}\n")
                    logger.info(f"Trade executed: {trade_action} at ${close_price:.2f}, PnL: ${env.last_trade_profit:.2f}")

                peak_balance = max(peak_balance, env.balance)
                current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
                max_drawdown = max(max_drawdown, current_drawdown)

                step_counter += 1
                agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)

                if step_counter % 5 == 0 or env.position != prev_position:
                    agent.add_chart_to_tensorboard(env, step_counter)

                if trades_count > 0:
                    win_rate = (winning_trades / trades_count) * 100
                    agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
                    agent.writer.add_text('Performance',
                                          _performance_markdown(env, trades_count, win_rate, max_drawdown),
                                          step_counter)

                prev_position = env.position

                logger.info(f"Waiting for next candle... (Step {step_counter})")
                await asyncio.sleep(10)  # Check every 10 seconds

            except Exception as e:
                logger.error(f"Error in live trading loop: {str(e)}")
                logger.error(traceback.format_exc())
                logger.info("Continuing after error...")
                await asyncio.sleep(30)  # Wait longer after an error

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")
    finally:
        # Under asyncio.run, Ctrl+C usually arrives as task cancellation rather
        # than KeyboardInterrupt, so the report lives in finally.
        _log_session_summary(env, trades_count, winning_trades, total_profit,
                             max_drawdown=max_drawdown, trade_log_path=trade_log_path)


async def get_latest_candle(exchange, symbol, timeframe="1m"):
    """Return the most recent ``timeframe`` candle for ``symbol`` as a dict, or None on failure.

    Args:
        exchange: CCXT exchange (sync or async).
        symbol: Trading pair, e.g. ``"ETH/USDT"``.
        timeframe: Candle timeframe; defaults to ``"1m"`` as before.
    """
    try:
        data = await fetch_ohlcv_data(exchange, symbol, timeframe, 1)
        if data:
            # The newest candle is last in CCXT's chronological ordering.
            return data[-1]
        logger.warning("No candle data received")
        return None
    except Exception as e:
        logger.error(f"Failed to fetch latest candle: {e}")
        return None


async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
    """Fetch OHLCV candles from a sync or async CCXT exchange.

    Async clients (``ccxt.pro`` / ``ccxt.async_support``) are awaited directly.
    Sync clients run in the default thread-pool executor so the event loop is
    not blocked.

    Returns:
        List of dicts with keys 'timestamp', 'open', 'high', 'low', 'close',
        'volume'; an empty list on any error.
    """
    import inspect

    try:
        fetch = getattr(exchange, 'fetch_ohlcv', None) or getattr(exchange, 'fetchOHLCV', None)
        if fetch is None:
            logger.error("Exchange does not support OHLCV data fetching")
            return []

        if inspect.iscoroutinefunction(fetch):
            # Running an async method in an executor only returns an un-awaited
            # coroutine, so async clients must be awaited here.
            ohlcv = await fetch(symbol, timeframe, limit=limit)
        else:
            loop = asyncio.get_running_loop()
            ohlcv = await loop.run_in_executor(
                None, lambda: fetch(symbol, timeframe, limit=limit)
            )

        data = [
            {
                'timestamp': candle[0],
                'open': candle[1],
                'high': candle[2],
                'low': candle[3],
                'close': candle[4],
                'volume': candle[5],
            }
            for candle in (ohlcv or [])
        ]

        logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
        return data

    except Exception as e:
        logger.error(f"Failed to fetch OHLCV data: {e}")
        return []


# Candle timeframes supported by the WebSocket pipeline, in seconds.
_TIMEFRAME_SECONDS = {"1m": 60, "5m": 300, "15m": 900, "1h": 3600}


def _timeframe_to_seconds(timeframe, default=60):
    """Return the length of ``timeframe`` in seconds, ``default`` (1m) when unknown."""
    return _TIMEFRAME_SECONDS.get(timeframe, default)


async def initialize_websocket_data_stream(symbol="ETH/USDT", timeframe="1m"):
    """Load historical candles and open a Binance WebSocket for real-time ticks.

    Args:
        symbol: Trading pair symbol (e.g. "ETH/USDT").
        timeframe: Candle timeframe (e.g. "1m"); unknown values fall back to 1m.

    Returns:
        Tuple ``(websocket, initial_candles)``: the connected BinanceWebSocket
        and a list of candle dicts (timestamp in ms, open, high, low, close,
        volume). Returns ``(None, [])`` on any failure.
    """
    try:
        historical_data = BinanceHistoricalData()
        interval_seconds = _timeframe_to_seconds(timeframe)

        initial_data = historical_data.get_historical_candles(
            symbol=symbol,
            interval_seconds=interval_seconds,
            limit=1000  # Get 1000 candles for good history
        )

        initial_candles = []
        if initial_data is not None and not initial_data.empty:
            initial_candles = [
                {
                    'timestamp': int(row['timestamp'].timestamp() * 1000),
                    'open': float(row['open']),
                    'high': float(row['high']),
                    'low': float(row['low']),
                    'close': float(row['close']),
                    'volume': float(row['volume'])
                }
                for _, row in initial_data.iterrows()
            ]
            logger.info(f"Loaded {len(initial_candles)} historical candles")
        else:
            logger.warning("No historical data fetched")

        binance_ws = BinanceWebSocket(symbol.replace('/', ''))
        await binance_ws.connect()

        logger.info(f"WebSocket for {symbol} initialized successfully")
        return binance_ws, initial_candles

    except Exception as e:
        logger.error(f"Failed to initialize WebSocket data stream: {e}")
        logger.error(traceback.format_exc())
        return None, []


async def process_websocket_ticks(websocket, env, agent=None, demo=True, timeframe="1m"):
    """Aggregate WebSocket ticks into candles and optionally trade on each closed candle.

    Ticks are bucketed on epoch-aligned ``timeframe`` boundaries. When a tick
    opens a new bucket, the previous candle is passed to ``env.add_data``. If
    ``agent`` is given, it then picks an action and ``env.step`` is called once.
    The WebSocket is closed when processing ends.

    Args:
        websocket: BinanceWebSocket instance (``running``, ``receive``, ``close``).
        env: TradingEnvironment instance.
        agent: Agent instance (optional, for live trading).
        demo: Whether to run in demo mode (currently unused here).
        timeframe: Candle timeframe; unknown values fall back to 1m.
    """
    if timeframe not in _TIMEFRAME_SECONDS:
        logger.warning(f"Unsupported timeframe {timeframe!r} for tick aggregation, using 1m candles")
    bucket_ms = _timeframe_to_seconds(timeframe) * 1000

    current_candle = None
    current_bucket = None
    step_counter = 0

    try:
        logger.info("Starting WebSocket tick processing...")

        while websocket.running:
            tick = await websocket.receive()
            if tick is None:
                await asyncio.sleep(0.1)
                continue

            timestamp = tick.get('timestamp')
            price = tick.get('price')
            # A tick without volume must not break the running volume sum.
            volume = tick.get('volume') or 0

            if timestamp is None or price is None:
                logger.warning(f"Invalid tick data received: {tick}")
                continue

            # Start of the candle this tick belongs to, in epoch milliseconds.
            tick_bucket = int(timestamp) // bucket_ms * bucket_ms

            if current_bucket is None or tick_bucket > current_bucket:
                if current_candle is not None:
                    env.add_data(current_candle)
                    if agent is not None:
                        # Contain per-candle failures so one bad step does not end the stream.
                        try:
                            state = env.get_state()
                            action = agent.select_action(state, training=False)
                            next_state, reward, done, info = env.step(action)
                            logger.info(f"Step {step_counter}: Action {_action_name(action)}, "
                                        f"Price: ${current_candle['close']:.2f}, Balance: ${env.balance:.2f}")
                            step_counter += 1
                        except Exception as e:
                            logger.error(f"Error processing closed candle: {e}")
                            logger.error(traceback.format_exc())

                current_bucket = tick_bucket
                current_candle = {
                    'timestamp': current_bucket,
                    'open': price,
                    'high': price,
                    'low': price,
                    'close': price,
                    'volume': volume
                }
                logger.debug(f"Started new candle at {dt.fromtimestamp(current_bucket / 1000)}")
            else:
                # Tick inside the current bucket (or a late tick): extend the candle.
                current_candle['high'] = max(current_candle['high'], price)
                current_candle['low'] = min(current_candle['low'], price)
                current_candle['close'] = price
                current_candle['volume'] += volume

    except asyncio.CancelledError:
        logger.info("WebSocket processing canceled")
    except Exception as e:
        logger.error(f"Error in WebSocket tick processing: {e}")
        logger.error(traceback.format_exc())
    finally:
        if websocket:
            await websocket.close()
            logger.info("WebSocket connection closed")


def ensure_pytorch_compatibility():
    """Allow NumPy scalars in PyTorch >= 2.6 weights-only checkpoint loading.

    ``torch.serialization.add_safe_globals`` expects the actual callables.
    Registering the dotted-path *string* (as before) does not allow-list
    anything, and a non-callable entry can break weights-only loading.
    """
    try:
        from torch.serialization import add_safe_globals

        try:
            from numpy._core.multiarray import scalar as numpy_scalar  # NumPy >= 2
        except ImportError:
            from numpy.core.multiarray import scalar as numpy_scalar  # NumPy 1.x

        add_safe_globals([numpy_scalar])
        logger.info("Added numpy scalar to PyTorch safe globals")
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure PyTorch compatibility: {e}")
        logger.warning("This might cause issues with model loading in PyTorch 2.6+")


async def main():
    """Command-line entry point: train, evaluate, or live-trade the agent.

    ``--symbol`` and ``--timeframe`` select the market data in every mode; in
    train mode they were previously ignored in favour of ETH/USDT 1m, which
    are still the defaults. The exchange connection is always closed on exit.
    """
    ensure_pytorch_compatibility()

    parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
                        help='Operation mode: train, eval, or live')
    parser.add_argument('--episodes', type=int, default=1000,
                        help='Number of episodes for training or evaluation')
    parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
                        help='Run in demo mode (paper trading) if true')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m',
                        help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
    parser.add_argument('--leverage', type=int, default=50,
                        help='Leverage for futures trading')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model file for evaluation or live trading')
    parser.add_argument('--use-websocket', action='store_true',
                        help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')

    args = parser.parse_args()
    demo_mode = args.demo.lower() == 'true'
    device = get_device()

    exchange = None
    try:
        exchange = await initialize_exchange()
        env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)

        if args.mode == 'train':
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            # STATE_SIZE and action_size=4 keep saved models loadable in eval/live.
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            logger.info(f"Starting training for {args.episodes} episodes...")
            await train_agent(agent, env, num_episodes=args.episodes)

        else:  # 'eval' or 'live' (argparse restricts --mode)
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            model_path = args.model or "models/trading_agent_best_pnl.pt"
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return

            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            # NumPy safe globals were registered by ensure_pytorch_compatibility() above.
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.error(f"Failed to load model: {e}")
                if args.mode == 'live':
                    confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
                    if confirmation.lower() != 'y':
                        logger.info("Live trading canceled by user")
                        return
                    logger.info("Continuing with a new model")
                else:
                    logger.info("Continuing evaluation with a new model")

            if args.mode == 'eval':
                logger.info("Evaluating agent...")
                evaluate_agent(agent, env, num_episodes=args.episodes)
            else:
                logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
                logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")

                if args.use_websocket:
                    logger.info("Using Binance WebSocket for real-time data")
                    await live_trading_with_websocket(
                        agent=agent,
                        env=env,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage
                    )
                else:
                    logger.info("Using CCXT for real-time data")
                    await live_trading(
                        agent=agent,
                        env=env,
                        exchange=exchange,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage
                    )

    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
    finally:
        await _close_exchange(exchange)


def _plot_trade_signals(price_ax, trade_signals):
    """Overlay trade markers on ``price_ax``; malformed signals are logged and skipped.

    Exact types 'buy', 'sell', 'close_long' and 'close_short' get their own
    markers. Types containing 'stop_loss' or 'take_profit' get theirs. Any
    other type is not drawn.
    """
    for signal in trade_signals:
        try:
            timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
            price = signal['price']
            kind = signal['type']

            if kind == 'buy':
                marker, color = '^', 'green'
            elif kind == 'sell':
                marker, color = 'v', 'red'
            elif kind == 'close_long':
                marker, color = 'x', 'gold'
            elif kind == 'close_short':
                marker, color = 'x', 'black'
            elif 'stop_loss' in kind:
                marker, color = 'X', 'purple'
            elif 'take_profit' in kind:
                marker, color = '*', 'cyan'
            else:
                continue

            price_ax.plot(timestamp, price, marker, color=color, markersize=10)
        except Exception as e:
            logger.warning(f"Error plotting signal: {e}")


def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
    """Render recent price/volume plus trade markers to a PIL image for TensorBoard.

    Args:
        data: Candle dicts with 'timestamp' (ms), 'close' and 'volume'; only
            the last ``window_size`` are drawn.
        trade_signals: Signal dicts with 'timestamp' (ms), 'price' and 'type';
            the last one may carry 'balance' and 'pnl' for an annotation.
        window_size: Number of trailing candles to draw.
        title: Chart title.

    Returns:
        A PIL image of the PNG chart, or None when ``data`` has fewer than 10
        candles or rendering fails.
    """
    if len(data) < 10:
        return None

    fig = None
    try:
        fig = plt.figure(figsize=(12, 8))

        df = pd.DataFrame(data[-window_size:])
        df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('date', inplace=True)

        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        price_ax = plt.subplot(gs[0])
        volume_ax = plt.subplot(gs[1], sharex=price_ax)

        # `mpf` is bound to matplotlib.pyplot in this file (not mplfinance), so
        # the mplfinance-style candle call always failed. Draw the close line and
        # volume bars directly instead.
        price_ax.plot(df.index, df['close'], label='Price')
        volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)

        _plot_trade_signals(price_ax, trade_signals)

        if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
            balance = trade_signals[-1]['balance']
            pnl = trade_signals[-1]['pnl']
            price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
                              xy=(0.02, 0.95), xycoords='axes fraction',
                              bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

        price_ax.set_title(title)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return Image.open(buf)

    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return None
    finally:
        # Close on every path; a failure mid-render previously leaked the figure.
        if fig is not None:
            plt.close(fig)


async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
    """Run the trading bot on real-time Binance WebSocket data.

    Historical candles seed ``env``, then a background task
    (``process_websocket_ticks``) aggregates ticks and makes the trading
    decisions. This coroutine polls ``env`` once per second, logs position
    changes to ``trade_logs/trades_ws_<time>.csv`` and writes TensorBoard
    metrics. It stops when the tick task finishes, on interruption, or on error.

    Args:
        agent: The trading agent to use for decision making.
        env: The trading environment.
        symbol: The trading pair symbol (e.g. "ETH/USDT").
        timeframe: The candlestick timeframe (e.g. "1m").
        demo: Whether to run in demo mode (paper trading).
        leverage: Leverage reported when futures trading is initialised.

    Returns:
        None
    """
    logger.info(f"Starting live trading with WebSocket for {symbol} on {timeframe} timeframe")
    logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")

    if not demo:
        confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
        if confirmation != "CONFIRM":
            logger.info("Live trading canceled by user")
            return

    if not hasattr(agent, 'writer') or agent.writer is None:
        # `datetime` is the module in this file; `dt` is the datetime class.
        current_time = dt.now().strftime("%Y%m%d_%H%M%S")
        agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')

    trades_count = 0
    winning_trades = 0
    total_profit = 0
    max_drawdown = 0
    peak_balance = env.balance
    step_counter = 0

    os.makedirs('trade_logs', exist_ok=True)
    current_time = dt.now().strftime("%Y%m%d_%H%M%S")
    trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
    with open(trade_log_path, 'w') as f:
        f.write("timestamp,action,price,position_size,balance,pnl\n")

    websocket = None
    websocket_task = None
    exchange = None
    try:
        websocket, initial_candles = await initialize_websocket_data_stream(symbol, timeframe)
        if websocket is None or not initial_candles:
            logger.error("Failed to initialize WebSocket data stream")
            return

        logger.info(f"Loading {len(initial_candles)} initial candles into environment")
        for candle in initial_candles:
            env.add_data(candle)
        env.reset()

        if not demo:
            exchange = await initialize_exchange()
            if exchange:
                try:
                    await env.initialize_futures(exchange)
                    logger.info(f"Futures trading initialized with {leverage}x leverage")
                except Exception as e:
                    logger.error(f"Failed to initialize futures trading: {str(e)}")
                    logger.info("Falling back to demo mode for safety")
                    demo = True

        websocket_task = asyncio.create_task(
            process_websocket_ticks(websocket, env, agent, demo, timeframe)
        )

        prev_position = 'flat'
        while True:
            # The tick task drives the environment; once it stops (stream closed
            # or crashed) there is nothing left to monitor.
            if websocket_task.done():
                logger.warning("WebSocket processing stopped; ending live trading monitor")
                break

            try:
                if env.position != prev_position:
                    trades_count += 1
                    last_profit = getattr(env, 'last_trade_profit', 0)
                    if last_profit > 0:
                        winning_trades += 1
                    total_profit += last_profit

                    action_name = _action_name(getattr(env, 'last_action', 0))
                    with open(trade_log_path, 'a') as f:
                        f.write(f"{dt.now().isoformat()},{action_name},{env.current_price},{env.position_size},{env.balance},{last_profit}\n")
                    logger.info(f"Trade executed: {action_name} at ${env.current_price:.2f}, PnL: ${last_profit:.2f}")

                peak_balance = max(peak_balance, env.balance)
                current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
                max_drawdown = max(max_drawdown, current_drawdown)

                step_counter += 1
                if step_counter % 10 == 0:  # Update every 10 steps
                    agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                    agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                    agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)

                    if trades_count > 0:
                        win_rate = (winning_trades / trades_count) * 100
                        agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
                        agent.writer.add_text('Performance',
                                              _performance_markdown(env, trades_count, win_rate, max_drawdown),
                                              step_counter)

                if step_counter % 30 == 0 or env.position != prev_position:
                    agent.add_chart_to_tensorboard(env, step_counter)

                prev_position = env.position

                await asyncio.sleep(1)  # Sleep briefly to avoid hogging the CPU

            except Exception as e:
                logger.error(f"Error in live trading monitor loop: {str(e)}")
                logger.error(traceback.format_exc())
                await asyncio.sleep(10)  # Wait longer after an error

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")
    except Exception as e:
        logger.error(f"Critical error in live trading: {str(e)}")
        logger.error(traceback.format_exc())
    finally:
        if websocket_task is not None and not websocket_task.done():
            websocket_task.cancel()
            try:
                await websocket_task
            except asyncio.CancelledError:
                pass

        _log_session_summary(env, trades_count, winning_trades, total_profit)

        if websocket:
            await websocket.close()
        # Closed exactly once here (previously also in the interrupt branch).
        await _close_exchange(exchange)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")