import os import time import json import numpy as np import pandas as pd from datetime import datetime import random import logging import asyncio import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from collections import deque, namedtuple from dotenv import load_dotenv import ccxt import websockets from torch.utils.tensorboard import SummaryWriter import torch.cuda.amp as amp # Add this import at the top from sklearn.preprocessing import MinMaxScaler import copy import argparse import traceback import io import matplotlib.dates as mdates from matplotlib.figure import Figure from PIL import Image import matplotlib.pyplot as mpf import matplotlib.gridspec as gridspec import datetime from datetime import datetime as dt from collections import defaultdict from gym.spaces import Discrete, Box import csv import gc import shutil import math import platform import ctypes # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()] ) logger = logging.getLogger("trading_bot") # Load environment variables load_dotenv() MEXC_API_KEY = os.getenv('MEXC_API_KEY') MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY') # Constants INITIAL_BALANCE = 100 # USD MAX_LEVERAGE = 100 STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5% MEMORY_SIZE = 100000 BATCH_SIZE = 64 GAMMA = 0.99 # Discount factor EPSILON_START = 1.0 EPSILON_END = 0.05 EPSILON_DECAY = 10000 STATE_SIZE = 64 # Size of our state representation LEARNING_RATE = 1e-4 TARGET_UPDATE = 10 # Update target network every 10 episodes # Experience replay tuple Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done']) # Add this function near the top of the file, after the imports but before any classes def find_local_extrema(prices, window=5): """ Find local extrema (tops and bottoms) in price series. Args: prices: Array of price values window: Window size for finding extrema Returns: Tuple of (tops, bottoms) indices """ tops = [] bottoms = [] if len(prices) < window * 2 + 1: return tops, bottoms try: # Use peak detection algorithms from scipy if available from scipy.signal import find_peaks # Find peaks (tops) peaks, _ = find_peaks(prices, distance=window) tops = list(peaks) # Find valleys (bottoms) by inverting the prices valleys, _ = find_peaks(-prices, distance=window) bottoms = list(valleys) # Optional: Filter extrema for significance if len(tops) > 0 and len(bottoms) > 0: # Calculate average price move avg_move = np.mean(np.abs(np.diff(prices))) # Filter tops and bottoms for significant moves filtered_tops = [] for top in tops: # Check if this top is significantly higher than surrounding points if top > window and top < len(prices) - window: surrounding_min = min(prices[top-window:top+window]) if prices[top] - surrounding_min > avg_move * 1.5: # 1.5x average move filtered_tops.append(top) filtered_bottoms = [] for bottom in bottoms: # Check if this bottom is significantly lower than surrounding points if bottom > window and bottom < len(prices) - window: surrounding_max = max(prices[bottom-window:bottom+window]) if surrounding_max - prices[bottom] > avg_move * 1.5: # 1.5x average move filtered_bottoms.append(bottom) tops = filtered_tops bottoms = filtered_bottoms except ImportError: # Fallback to manual detection if scipy is not available for i in range(window, len(prices) - window): # Check if this point is a local maximum if all(prices[i] >= prices[i - j] for j in range(1, window + 1)) and \ all(prices[i] >= prices[i + j] for j in range(1, window + 1)): tops.append(i) # Check if this point is a local minimum if all(prices[i] <= prices[i - j] for j in range(1, window + 1)) and \ all(prices[i] <= prices[i + j] for j in range(1, window + 1)): bottoms.append(i) return tops, bottoms class ReplayMemory: def __init__(self, capacity): self.memory = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.memory.append(Experience(state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.memory, batch_size) def __len__(self): return len(self.memory) class DQN(nn.Module): """Deep Q-Network with enhanced architecture""" def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(DQN, self).__init__() self.network = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads) self.hidden_size = hidden_size self.lstm_layers = lstm_layers self.attention_heads = attention_heads def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None): # Pass through to LSTMAttentionDQN if x_1m is not None and x_1h is not None and x_1d is not None: return self.network(state, x_1s, x_1m, x_1h, x_1d) else: return self.network(state) class PricePredictionModel(nn.Module): def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2): super(PricePredictionModel, self).__init__() self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2) self.fc = nn.Linear(hidden_size, output_size) self.scaler = MinMaxScaler(feature_range=(0, 1)) self.is_fitted = False def forward(self, x): # x shape: [batch_size, seq_len, 1] lstm_out, _ = self.lstm(x) # Use the last time step output predictions = self.fc(lstm_out[:, -1, :]) return predictions def preprocess(self, data): # Reshape data for scaler data_reshaped = np.array(data).reshape(-1, 1) # Fit scaler if not already fitted if not self.is_fitted: self.scaler.fit(data_reshaped) self.is_fitted = True # Transform data scaled_data = self.scaler.transform(data_reshaped) return scaled_data def postprocess(self, scaled_predictions): # Inverse transform to get actual price values return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten() def predict_next_candles(self, price_history, num_candles=5): if len(price_history) < 30: # Need enough history return np.zeros(num_candles) # Preprocess data scaled_data = self.preprocess(price_history) # Create sequence sequence = scaled_data[-30:].reshape(1, 30, 1) sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device) # Get predictions with torch.no_grad(): scaled_predictions = self(sequence_tensor).cpu().numpy()[0] # Postprocess predictions predictions = self.postprocess(scaled_predictions) return predictions def train_on_new_data(self, price_history, optimizer, epochs=10): if len(price_history) < 35: # Need enough history for training return 0.0 # Preprocess data scaled_data = self.preprocess(price_history) # Create sequences and targets sequences = [] targets = [] for i in range(len(scaled_data) - 35): # Sequence: 30 time steps seq = scaled_data[i:i+30] # Target: next 5 time steps target = scaled_data[i+30:i+35].flatten() sequences.append(seq) targets.append(target) if not sequences: # If no sequences were created return 0.0 # Convert to tensors sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device) targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device) # Training loop total_loss = 0 for _ in range(epochs): # Forward pass predictions = self(sequences_tensor) # Calculate loss loss = F.mse_loss(predictions, targets_tensor) # Backward pass and optimize optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() return total_loss / epochs class TradingEnvironment: def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50, window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance', symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000): """Initialize the trading environment""" self.api_key = api_key self.api_secret = api_secret self.exchange_id = exchange_id self.symbol = symbol self.timeframe = timeframe self.init_length = init_length # TODO: For 1s/ticks timeframes, implement WebSocket API integration for real-time data try: # Initialize exchange if API credentials are provided if api_key and api_secret: self.exchange = initialize_exchange(exchange_id, api_key, api_secret) logger.info(f"Exchange initialized: {exchange_id}") # Fetch historical data self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length) if not self.data: raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}") self.data_format_is_list = isinstance(self.data[0], list) logger.info(f"Loaded {len(self.data)} candles from exchange") elif data is not None: # Use provided data self.data = data self.data_format_is_list = isinstance(self.data[0], list) logger.info(f"Using provided data with {len(self.data)} candles") else: # Initialize with empty data, we'll load it later with fetch_initial_data logger.warning("No data provided, initializing with empty data") self.data = [] self.data_format_is_list = True except Exception as e: logger.error(f"Error initializing environment: {e}") raise # Initialize features and feature extractors if features is not None: self.features = features # Create a dictionary of features self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)} else: # Initialize features as a dictionary, not a list self.features = { 'price': [], 'volume': [], 'rsi': [], 'macd': [], 'macd_signal': [], 'macd_hist': [], 'bollinger_upper': [], 'bollinger_mid': [], 'bollinger_lower': [], 'stoch_k': [], 'stoch_d': [], 'ema_9': [], 'ema_21': [], 'atr': [] } self.features_dict = {} if feature_extractors is None: feature_extractors = [] self.feature_extractors = feature_extractors # Environment parameters self.initial_balance = initial_balance self.balance = initial_balance self.leverage = leverage self.position = 'flat' # 'flat', 'long', or 'short' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 self.commission = commission self.total_pnl = 0 self.total_fees = 0.0 # Track total fees paid self.trades = [] self.trade_signals = [] self.current_step = 0 self.window_size = window_size self.max_steps = max_steps self.peak_balance = initial_balance self.max_drawdown = 0 self.current_price = 0 self.win_count = 0 self.loss_count = 0 self.min_position_size = 100 # Minimum position size in USD # Track candle patterns and reversal points self.patterns = {} self.reversal_points = [] # Define observation and action spaces num_features = len(self.features) if hasattr(self, 'features') and self.features else 0 state_dim = window_size * 5 + 5 + num_features # OHLCV + position info + features self.action_space = Discrete(4) # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32) # Check if we have enough data if len(self.data) < self.window_size: logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}") def calculate_reward(self, action): """Calculate reward based on the action taken""" reward = 0 # Base reward structure if self.position == 'flat': if action == 0: # HOLD when flat reward = 0.01 # Small reward for holding when no position elif action == 1: # BUY/LONG # Check for buy signal in CNN patterns if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns: buy_confidence = self.cnn_patterns['long_confidence'] # Scale by confidence reward = 0.1 * buy_confidence * 10 else: reward = 0.1 # Default reward for taking a position # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty elif action == 2: # SELL/SHORT # Check for sell signal in CNN patterns if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns: sell_confidence = self.cnn_patterns['short_confidence'] # Scale by confidence reward = 0.1 * sell_confidence * 10 else: reward = 0.1 # Default reward for taking a position # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty elif action == 3: # CLOSE when no position reward = -0.1 # Penalty for trying to close no position elif self.position == 'long': if action == 0: # HOLD long position # Calculate price change since entry price_change = (self.current_price - self.entry_price) / self.entry_price # Reward or penalize based on price movement if price_change > 0: reward = price_change * 10 # Reward for holding profitable position else: reward = price_change * 5 # Smaller penalty for holding losing position elif action == 1: # BUY when already long reward = -0.1 # Penalty for redundant action elif action == 2: # SELL when long (reversal) # Calculate PnL pnl_percent = (self.current_price - self.entry_price) / self.entry_price if pnl_percent > 0: reward = -0.5 # Penalty for closing profitable long position to go short else: # Check for sell signal in CNN patterns if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns: sell_confidence = self.cnn_patterns['short_confidence'] reward = 0.2 * sell_confidence * 10 # Reward for correct reversal else: reward = 0.2 # Default reward for cutting loss # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty elif action == 3: # CLOSE long position # Calculate PnL pnl_percent = (self.current_price - self.entry_price) / self.entry_price if pnl_percent > 0: reward = pnl_percent * 15 # Higher reward for taking profit else: reward = pnl_percent * 5 # Smaller penalty for cutting loss # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty elif self.position == 'short': if action == 0: # HOLD short position # Calculate price change since entry price_change = (self.entry_price - self.current_price) / self.entry_price # Reward or penalize based on price movement if price_change > 0: reward = price_change * 10 # Reward for holding profitable position else: reward = price_change * 5 # Smaller penalty for holding losing position elif action == 1: # BUY when short (reversal) # Calculate PnL pnl_percent = (self.entry_price - self.current_price) / self.entry_price if pnl_percent > 0: reward = -0.5 # Penalty for closing profitable short position to go long else: # Check for buy signal in CNN patterns if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns: buy_confidence = self.cnn_patterns['long_confidence'] reward = 0.2 * buy_confidence * 10 # Reward for correct reversal else: reward = 0.2 # Default reward for cutting loss # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty elif action == 2: # SELL when already short reward = -0.1 # Penalty for redundant action elif action == 3: # CLOSE short position # Calculate PnL pnl_percent = (self.entry_price - self.current_price) / self.entry_price if pnl_percent > 0: reward = pnl_percent * 15 # Higher reward for taking profit else: reward = pnl_percent * 5 # Smaller penalty for cutting loss # Apply fee penalty if self.position_size > 0: fee = (self.position_size / 1900) * 1 fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05 reward -= fee_penalty return reward def reset(self): """Reset the environment to its initial state and return the initial observation""" self.balance = self.initial_balance self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 self.current_step = 0 self.trades = [] self.trade_signals = [] self.total_pnl = 0.0 self.total_fees = 0.0 self.peak_balance = self.initial_balance self.max_drawdown = 0.0 self.win_count = 0 self.loss_count = 0 return self.get_state() def add_data(self, candle): """Add a new candle to the data""" # Check if candle is a list or dictionary if isinstance(candle, list): self.data_format_is_list = True self.data.append(candle) self.current_price = candle[4] # Close price is at index 4 else: self.data_format_is_list = False self.data.append(candle) self.current_price = candle['close'] self._update_features() def _initialize_features(self): """Initialize technical indicators and features""" if len(self.data) < 30: return # Convert data to pandas DataFrame for easier calculation if self.data_format_is_list: # Convert list format to DataFrame df = pd.DataFrame(self.data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) else: # Dictionary format df = pd.DataFrame(self.data) # Basic price and volume self.features['price'] = df['close'].values self.features['volume'] = df['volume'].values # Calculate RSI (14 periods) delta = df['close'].diff() gain = delta.where(delta > 0, 0).rolling(window=14).mean() loss = -delta.where(delta < 0, 0).rolling(window=14).mean() rs = gain / loss self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values # Calculate MACD ema12 = df['close'].ewm(span=12, adjust=False).mean() ema26 = df['close'].ewm(span=26, adjust=False).mean() macd = ema12 - ema26 signal = macd.ewm(span=9, adjust=False).mean() self.features['macd'] = macd.values self.features['macd_signal'] = signal.values self.features['macd_hist'] = (macd - signal).values # Calculate Bollinger Bands sma20 = df['close'].rolling(window=20).mean() std20 = df['close'].rolling(window=20).std() self.features['bollinger_upper'] = (sma20 + 2 * std20).values self.features['bollinger_mid'] = sma20.values self.features['bollinger_lower'] = (sma20 - 2 * std20).values # Calculate Stochastic Oscillator low_14 = df['low'].rolling(window=14).min() high_14 = df['high'].rolling(window=14).max() k = 100 * ((df['close'] - low_14) / (high_14 - low_14)) self.features['stoch_k'] = k.values self.features['stoch_d'] = k.rolling(window=3).mean().values # Calculate EMAs self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values # Calculate ATR high_low = df['high'] - df['low'] high_close = (df['high'] - df['close'].shift()).abs() low_close = (df['low'] - df['close'].shift()).abs() tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1) self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values def _update_features(self): """Update technical indicators with new data""" self._initialize_features() # Recalculate all features async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """Fetch initial historical data for the environment""" try: logger.info(f"Fetching initial data for {symbol}") # Use the refactored fetch method data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) # Update environment with fetched data if data: self.data = data self._initialize_features() logger.info(f"Initialized environment with {len(data)} candles") else: logger.warning("No initial data received") return len(data) > 0 except Exception as e: logger.error(f"Error fetching initial data: {e}") return False def step(self, action): """Take an action in the environment and return the next state, reward, and done flag""" # Check if we have enough data if self.current_step >= len(self.data) - 1: # We've reached the end of data done = True next_state = self.get_state() info = { 'action': 'none', 'price': self.current_price, 'balance': self.balance, 'position': self.position, 'pnl': self.total_pnl } return next_state, 0, done, info # Circuit breaker after consecutive losses if self.count_consecutive_losses() >= 5: logger.warning("Circuit breaker triggered after 5 consecutive losses") return self.get_state(), -1, True, {'action': 'circuit_breaker_triggered'} # Reduce leverage in volatile markets if self.is_volatile_market(): self.leverage = MAX_LEVERAGE * 0.5 # Half leverage in volatile markets else: self.leverage = MAX_LEVERAGE # Store current price before taking action if self.data_format_is_list: self.current_price = self.data[self.current_step][4] # Close price else: self.current_price = self.data[self.current_step]['close'] # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE) reward = self.calculate_reward(action) # Execute the action initial_balance = self.balance # Store initial balance to calculate PnL # Open long position if action == 1 and self.position != 'long': if self.position == 'short': # Close short position first if self.position_size > 0: # Calculate PnL pnl_percent = (self.entry_price - self.current_price) / self.entry_price pnl_dollar = pnl_percent * self.position_size * self.leverage # Update balance and record trade self.balance += pnl_dollar self.total_pnl += pnl_dollar # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Record trade trade_duration = self.current_step - self.entry_index if self.data_format_is_list: timestamp = self.data[self.current_step][0] # Timestamp else: timestamp = self.data[self.current_step]['timestamp'] self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'fee': fee, 'net_pnl': pnl_dollar - fee, 'duration': trade_duration, 'timestamp': timestamp, 'reason': 'action_change' }) # Update win/loss count if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 # Now open long position self.position = 'long' self.entry_price = self.current_price self.entry_index = self.current_step # Calculate position size with risk management self.position_size = self.calculate_position_size() # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Set stop loss and take profit sl_percent = 0.02 # 2% stop loss tp_percent = 0.04 # 4% take profit self.stop_loss = self.entry_price * (1 - sl_percent) self.take_profit = self.entry_price * (1 + tp_percent) # Open short position elif action == 2 and self.position != 'short': if self.position == 'long': # Close long position first if self.position_size > 0: # Calculate PnL pnl_percent = (self.current_price - self.entry_price) / self.entry_price pnl_dollar = pnl_percent * self.position_size * self.leverage # Update balance and record trade self.balance += pnl_dollar self.total_pnl += pnl_dollar # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Record trade trade_duration = self.current_step - self.entry_index if self.data_format_is_list: timestamp = self.data[self.current_step][0] # Timestamp else: timestamp = self.data[self.current_step]['timestamp'] self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'fee': fee, 'net_pnl': pnl_dollar - fee, 'duration': trade_duration, 'timestamp': timestamp, 'reason': 'action_change' }) # Update win/loss count if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 # Now open short position self.position = 'short' self.entry_price = self.current_price self.entry_index = self.current_step # Calculate position size with risk management self.position_size = self.calculate_position_size() # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Set stop loss and take profit sl_percent = 0.02 # 2% stop loss tp_percent = 0.04 # 4% take profit self.stop_loss = self.entry_price * (1 + sl_percent) self.take_profit = self.entry_price * (1 - tp_percent) # Close position elif action == 3 and self.position != 'flat': if self.position == 'long': # Calculate PnL pnl_percent = (self.current_price - self.entry_price) / self.entry_price pnl_dollar = pnl_percent * self.position_size * self.leverage # Update balance and record trade self.balance += pnl_dollar self.total_pnl += pnl_dollar # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Record trade trade_duration = self.current_step - self.entry_index if self.data_format_is_list: timestamp = self.data[self.current_step][0] # Timestamp else: timestamp = self.data[self.current_step]['timestamp'] self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'fee': fee, 'net_pnl': pnl_dollar - fee, 'duration': trade_duration, 'timestamp': timestamp, 'reason': 'close_action' }) # Update win/loss count if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 elif self.position == 'short': # Calculate PnL pnl_percent = (self.entry_price - self.current_price) / self.entry_price pnl_dollar = pnl_percent * self.position_size * self.leverage # Update balance and record trade self.balance += pnl_dollar self.total_pnl += pnl_dollar # Apply trading fee (1 USD per 1.9k position) fee = (self.position_size / 1900) * 1 self.balance -= fee self.total_fees += fee # Record trade trade_duration = self.current_step - self.entry_index if self.data_format_is_list: timestamp = self.data[self.current_step][0] # Timestamp else: timestamp = self.data[self.current_step]['timestamp'] self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.current_price, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'fee': fee, 'net_pnl': pnl_dollar - fee, 'duration': trade_duration, 'timestamp': timestamp, 'reason': 'close_action' }) # Update win/loss count if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 # Reset position self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 # Record trade signal for visualization if action > 0: # If not HOLD signal_type = None if action == 1: # BUY/LONG signal_type = 'buy' elif action == 2: # SELL/SHORT signal_type = 'sell' elif action == 3: # CLOSE if self.position == 'long': signal_type = 'close_long' elif self.position == 'short': signal_type = 'close_short' if signal_type: if self.data_format_is_list: timestamp = self.data[self.current_step][0] # Timestamp else: timestamp = self.data[self.current_step]['timestamp'] self.trade_signals.append({ 'timestamp': timestamp, 'price': self.current_price, 'type': signal_type, 'balance': self.balance, 'pnl': self.total_pnl }) # Check for stop loss / take profit hits self.check_sl_tp() # Move to next step self.current_step += 1 done = self.current_step >= len(self.data) - 1 # Get new state next_state = self.get_state() # Update peak balance and drawdown if self.balance > self.peak_balance: self.peak_balance = self.balance current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0 self.max_drawdown = max(self.max_drawdown, current_drawdown) # Create info dictionary info = { 'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close', 'price': self.current_price, 'balance': self.balance, 'position': self.position, 'pnl': self.total_pnl, 'fees': self.total_fees, 'net_pnl': self.total_pnl - self.total_fees } return next_state, reward, done, info def check_sl_tp(self): """Check if stop loss or take profit has been hit with improved trailing stop""" if self.position == 'flat': return if self.position == 'long': # Implement trailing stop loss if in profit if self.current_price > self.entry_price * 1.01: self.stop_loss = max(self.stop_loss, self.current_price * 0.995) # Trail at 0.5% below current price # Check stop loss if self.current_price <= self.stop_loss: # Stop loss hit pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar # Record trade trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': trade_duration, 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'stop_loss' }) if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 # Check take profit elif self.current_price >= self.take_profit: # Take profit hit pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar # Record trade trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': trade_duration, 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'take_profit' }) self.win_count += 1 logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': # Implement trailing stop loss if in profit if self.current_price < self.entry_price * 0.99: self.stop_loss = min(self.stop_loss, self.current_price * 1.005) # Trail at 0.5% above current price # Check stop loss if self.current_price >= self.stop_loss: # Stop loss hit pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar # Record trade trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': trade_duration, 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'stop_loss' }) if pnl_dollar > 0: self.win_count += 1 else: self.loss_count += 1 logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 # Check take profit elif self.current_price <= self.take_profit: # Take profit hit pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100 pnl_dollar = pnl_percent / 100 * self.position_size # Apply fees pnl_dollar -= self.calculate_fees(self.position_size) # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar # Record trade trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, 'duration': trade_duration, 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'take_profit' }) self.win_count += 1 logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") # Reset position self.position = 'flat' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 def get_state(self): """Create state representation for the agent with enhanced features""" # Ensure we have enough data if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0: # Return zeros if not enough data return np.zeros(STATE_SIZE) # Create a normalized state vector with recent price action and indicators state_components = [] # Safely get the latest price try: latest_price = self.features['price'][-1] except IndexError: # If we can't get the latest price, return zeros return np.zeros(STATE_SIZE) # Safely get price features try: # Price features (normalize recent prices by the latest price) price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0 state_components.append(price_features) except (IndexError, ZeroDivisionError): # If we can't get price features, use zeros state_components.append(np.zeros(10)) # Safely get volume features try: # Volume features (normalize by max volume) max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1 vol_features = np.array(self.features['volume'][-5:]) / max_vol state_components.append(vol_features) except (IndexError, ZeroDivisionError): # If we can't get volume features, use zeros state_components.append(np.zeros(5)) # Technical indicators rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1 state_components.append(rsi) # MACD (normalize) macd_vals = np.array(self.features['macd'][-3:]) macd_signal = np.array(self.features['macd_signal'][-3:]) macd_hist = np.array(self.features['macd_hist'][-3:]) macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5) macd_norm = macd_vals / macd_scale macd_signal_norm = macd_signal / macd_scale macd_hist_norm = macd_hist / macd_scale state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm]) # Bollinger position (where is price relative to bands) bb_upper = np.array(self.features['bollinger_upper'][-3:]) bb_lower = np.array(self.features['bollinger_lower'][-3:]) bb_mid = np.array(self.features['bollinger_mid'][-3:]) price = np.array(self.features['price'][-3:]) # Calculate position of price within Bollinger Bands (0 to 1) bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] state_components.append(np.array(bb_pos)) # Stochastic oscillator state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) # Add predicted prices (if available) if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: # Normalize predictions relative to current price pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0 state_components.append(pred_norm) else: # Add zeros if no predictions state_components.append(np.zeros(3)) # Add extrema signals (if available) if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: # Get recent signals idx = len(self.optimal_signals) - 5 if idx < 0: idx = 0 recent_signals = self.optimal_signals[idx:idx+5] # Pad if needed if len(recent_signals) < 5: recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') state_components.append(recent_signals) else: # Add zeros if no signals state_components.append(np.zeros(5)) # Position info position_info = np.zeros(5) if self.position == 'long': position_info[0] = 1.0 # Position is long position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss % position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance elif self.position == 'short': position_info[0] = -1.0 # Position is short position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL % position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss % position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit % position_info[4] = self.position_size / self.balance # Position size relative to balance state_components.append(position_info) # NEW FEATURES START HERE # 1. Price momentum features (rate of change) if len(self.features['price']) >= 20: roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0 momentum_features = np.array([roc_5, roc_10, roc_20]) state_components.append(momentum_features) else: state_components.append(np.zeros(3)) # 2. Volatility features if len(self.features['price']) >= 20: # Calculate price returns returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] # Calculate volatility (standard deviation of returns) volatility = np.std(returns) # Calculate normalized high-low range based on data format if self.data_format_is_list: # List format: high at index 2, low at index 3, close at index 4 high_low_range = np.mean([ (self.data[i][2] - self.data[i][3]) / self.data[i][4] for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) ]) if len(self.data) > 0 else 0 else: # Dictionary format high_low_range = np.mean([ (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) ]) if len(self.data) > 0 else 0 # ATR normalized by price atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 volatility_features = np.array([volatility, high_low_range, atr_norm]) state_components.append(volatility_features) else: state_components.append(np.zeros(3)) # 3. Market regime features if len(self.features['price']) >= 50: # Trend strength (ADX-like measure) ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price trend_strength = abs(ema9 - ema21) / ema21 # Detect if in range or trending is_range_bound = 1.0 if self.is_uncertain_market() else 0.0 is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0 # Detect if near support/resistance near_support = 1.0 if self.is_near_support() else 0.0 near_resistance = 1.0 if self.is_near_resistance() else 0.0 market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance]) state_components.append(market_regime) else: state_components.append(np.zeros(5)) # 4. Trade history features if len(self.trades) > 0: # Recent win/loss ratio recent_trades = self.trades[-min(10, len(self.trades)):] win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades) # Average profit/loss avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0 avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0 # Normalize by balance avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0 avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0 # Last trade result last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0 trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]) state_components.append(trade_history) else: state_components.append(np.zeros(4)) # Combine all features state = np.concatenate([comp.flatten() for comp in state_components]) # Replace any NaN or infinite values state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) # Ensure the state has the correct size if len(state) != STATE_SIZE: logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}") # Pad or truncate to match expected size if len(state) < STATE_SIZE: state = np.pad(state, (0, STATE_SIZE - len(state))) else: state = state[:STATE_SIZE] return state def get_expanded_state_size(self): """Calculate the size of the expanded state representation""" # Create a dummy state to get its size state = self.get_state() return len(state) async def expand_model_with_new_features(agent, env): """Expand the model to handle new features without retraining from scratch""" # Get the new state size new_state_size = env.get_expanded_state_size() # Only expand if the new state size is larger if new_state_size > agent.state_size: logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})") # Expand the model success = agent.expand_model( new_state_size=new_state_size, new_hidden_size=512, # Increase hidden size for more capacity new_lstm_layers=3, # More layers for deeper patterns new_attention_heads=8 # More attention heads for complex relationships ) if success: logger.info(f"Model successfully expanded to handle {new_state_size} features") return True else: logger.error("Failed to expand model") return False else: logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient") return True def calculate_reward(self, action): """ Calculate reward for taking the given action. Args: action: The action taken (0=hold, 1=buy, 2=sell, 3=close) Returns: The calculated reward """ reward = 0 # Get current price if self.data_format_is_list: current_price = self.data[self.current_step][4] # Close price else: current_price = self.data[self.current_step]['close'] # Base reward component based on price movement price_change_pct = 0 if self.current_step > 0: if self.data_format_is_list: prev_price = self.data[self.current_step-1][4] # Previous close else: prev_price = self.data[self.current_step-1]['close'] price_change_pct = (current_price - prev_price) / prev_price # Check if we have CNN patterns available pattern_confidence = 0 if hasattr(self, 'cnn_patterns'): if action == 1 and 'long_confidence' in self.cnn_patterns: # Buy action pattern_confidence = self.cnn_patterns['long_confidence'] elif action == 2 and 'short_confidence' in self.cnn_patterns: # Sell action pattern_confidence = self.cnn_patterns['short_confidence'] # Action-specific rewards if action == 0: # HOLD # Small positive reward for holding in the right direction of market movement if self.position == 'long' and price_change_pct > 0: reward += 0.1 + price_change_pct * 10 elif self.position == 'short' and price_change_pct < 0: reward += 0.1 + abs(price_change_pct) * 10 else: # Small negative reward for holding in the wrong direction reward -= 0.1 elif action == 1 or action == 2: # BUY or SELL # Apply trading fee as negative reward (1 USD per 1.9k position size) position_size = self.calculate_position_size() fee = (position_size / 1900) * 1 # Trading fee in USD # Penalty for fee fee_penalty = fee / 10 # Scale down to make it a reasonable penalty reward -= fee_penalty # Logging if hasattr(self, 'total_fees'): self.total_fees += fee else: self.total_fees = fee elif action == 3: # CLOSE # Apply trading fee as negative reward (1 USD per 1.9k position size) fee = (self.position_size / 1900) * 1 # Trading fee in USD # Penalty for fee fee_penalty = fee / 10 # Scale down to make it a reasonable penalty reward -= fee_penalty # Logging if hasattr(self, 'total_fees'): self.total_fees += fee else: self.total_fees = fee # Add CNN pattern confidence to reward reward += pattern_confidence * 10 return reward async def initialize_futures(self, exchange): """Initialize futures trading parameters""" if not self.demo: try: # Set up futures trading parameters await exchange.set_position_mode(True) # Hedge mode await exchange.set_margin_mode("cross", symbol=self.futures_symbol) await exchange.set_leverage(self.leverage, symbol=self.futures_symbol) logger.info(f"Futures initialized with {self.leverage}x leverage") except Exception as e: logger.error(f"Failed to initialize futures trading: {str(e)}") logger.info("Falling back to demo mode for safety") demo = True async def execute_real_trade(self, exchange, action, current_price): """Execute real futures trade on MEXC""" try: position_size = self.calculate_position_size() if action == 1: # Open long order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='buy', amount=position_size, params={'positionSide': 'LONG'} ) logger.info(f"Opened LONG position: {order}") elif action == 2: # Open short order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='sell', amount=position_size, params={'positionSide': 'SHORT'} ) logger.info(f"Opened SHORT position: {order}") elif action == 3: # Close position position_side = 'LONG' if self.position == 'long' else 'SHORT' order = await exchange.create_order( symbol=self.futures_symbol, type='market', side='sell' if position_side == 'LONG' else 'buy', amount=self.position_size, params={'positionSide': position_side} ) logger.info(f"Closed {position_side} position: {order}") return order except Exception as e: logger.error(f"Trade execution failed: {e}") return None def trades_in_last_n_candles(self, n=20): """Count the number of trades in the last n candles""" if len(self.trades) == 0: return 0 if self.data_format_is_list: # List format: timestamp at index 0 current_time = self.data[self.current_step][0] n_candles_ago = self.data[max(0, self.current_step - n)][0] else: # Dictionary format current_time = self.data[self.current_step]['timestamp'] n_candles_ago = self.data[max(0, self.current_step - n)]['timestamp'] count = 0 for trade in reversed(self.trades): if 'timestamp' in trade and trade['timestamp'] >= n_candles_ago and trade['timestamp'] <= current_time: count += 1 else: # Older trades, we can stop counting break return count def count_consecutive_losses(self): """Count the number of consecutive losing trades""" count = 0 for trade in reversed(self.trades): if trade.get('pnl_dollar', 0) < 0: count += 1 else: break return count def is_volatile_market(self): """Determine if the current market is volatile""" if len(self.features['price']) < 20: return False recent_prices = self.features['price'][-20:] avg_price = sum(recent_prices) / len(recent_prices) volatility = sum([abs(p - avg_price) / avg_price for p in recent_prices]) / len(recent_prices) return volatility > 0.01 # 1% average deviation is considered volatile def is_uptrend(self): """Determine if the market is in an uptrend""" if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: return False # Short-term trend short_trend = self.features['ema_9'][-1] > self.features['ema_9'][-2] # Medium-term trend medium_trend = self.features['ema_9'][-1] > self.features['ema_21'][-1] return short_trend and medium_trend def is_downtrend(self): """Determine if the market is in a downtrend""" if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: return False # Short-term trend short_trend = self.features['ema_9'][-1] < self.features['ema_9'][-2] # Medium-term trend medium_trend = self.features['ema_9'][-1] < self.features['ema_21'][-1] return short_trend and medium_trend def calculate_position_size(self): """Calculate position size based on risk management rules""" # Reduce position size after losses consecutive_losses = self.count_consecutive_losses() risk_factor = max(0.3, 1.0 - (consecutive_losses * 0.1)) # Reduce by 10% per loss, min 30% # Calculate position size based on available balance and risk max_risk_amount = self.balance * 0.02 # Risk 2% per trade position_size = max_risk_amount / (STOP_LOSS_PERCENT / 100 * self.current_price) # Apply leverage position_size = position_size * self.leverage # Cap at available balance position_size = min(position_size, self.balance * self.leverage) return position_size * risk_factor # ... existing identify_optimal_trades method ... def update_cnn_patterns(self, candle_data=None): """ Update CNN patterns using multi-timeframe data. Args: candle_data: Dictionary containing candle data for different timeframes """ if not candle_data: return try: # Check if we have the necessary timeframes required_timeframes = ['1m', '1h', '1d'] if not all(tf in candle_data for tf in required_timeframes): logging.warning(f"Missing required timeframes for CNN pattern detection") return # Initialize patterns if not already done if not hasattr(self, 'cnn_patterns'): self.cnn_patterns = {} # Extract features from candle data features = {} # Process each timeframe for tf in required_timeframes: candles = candle_data[tf] if not candles or len(candles) < 30: continue # Convert to numpy arrays for easier processing closes = np.array([c[4] for c in candles[-100:]]) highs = np.array([c[2] for c in candles[-100:]]) lows = np.array([c[3] for c in candles[-100:]]) # Simple feature extraction # 1. Detect trends ema20 = self._calculate_ema(closes, 20) ema50 = self._calculate_ema(closes, 50) uptrend = ema20[-1] > ema50[-1] and closes[-1] > ema20[-1] downtrend = ema20[-1] < ema50[-1] and closes[-1] < ema20[-1] # 2. Detect potential reversal patterns # Find local extrema tops, bottoms = find_local_extrema(closes, window=10) # Check if we're near a bottom (potential buy) near_bottom = False bottom_confidence = 0 if bottoms and len(bottoms) > 0: last_bottom = bottoms[-1] if len(closes) - last_bottom < 5: # Recent bottom bottom_dist = abs(closes[-1] - closes[last_bottom]) / closes[last_bottom] if bottom_dist < 0.01: # Within 1% of the bottom near_bottom = True # Higher confidence if volume is increasing bottom_confidence = 0.8 - bottom_dist * 50 # 0.8 to 0.3 range # Check if we're near a top (potential sell) near_top = False top_confidence = 0 if tops and len(tops) > 0: last_top = tops[-1] if len(closes) - last_top < 5: # Recent top top_dist = abs(closes[-1] - closes[last_top]) / closes[last_top] if top_dist < 0.01: # Within 1% of the top near_top = True # Higher confidence if volume is increasing top_confidence = 0.8 - top_dist * 50 # 0.8 to 0.3 range # Store features for this timeframe features[tf] = { 'uptrend': uptrend, 'downtrend': downtrend, 'near_bottom': near_bottom, 'bottom_confidence': bottom_confidence, 'near_top': near_top, 'top_confidence': top_confidence } # Combine features across timeframes to get overall pattern confidence long_confidence = 0 short_confidence = 0 # Weight each timeframe (higher weight for longer timeframes) weights = {'1m': 0.2, '1h': 0.3, '1d': 0.5} for tf, tf_features in features.items(): weight = weights.get(tf, 0.2) # Add to long confidence if tf_features['uptrend'] or tf_features['near_bottom']: long_confidence += weight * (0.6 if tf_features['uptrend'] else 0) + \ weight * (tf_features['bottom_confidence'] if tf_features['near_bottom'] else 0) # Add to short confidence if tf_features['downtrend'] or tf_features['near_top']: short_confidence += weight * (0.6 if tf_features['downtrend'] else 0) + \ weight * (tf_features['top_confidence'] if tf_features['near_top'] else 0) # Normalize confidence scores to [0, 1] long_confidence = min(1.0, long_confidence) short_confidence = min(1.0, short_confidence) # Update patterns self.cnn_patterns = { 'long_confidence': long_confidence, 'short_confidence': short_confidence, 'features': features } logging.debug(f"Updated CNN patterns - Long: {long_confidence:.2f}, Short: {short_confidence:.2f}") except Exception as e: logging.error(f"Error updating CNN patterns: {e}") def _calculate_ema(self, data, span): """Calculate exponential moving average""" alpha = 2 / (span + 1) alpha_rev = 1 - alpha ema = np.zeros_like(data) ema[0] = data[0] for i in range(1, len(data)): ema[i] = alpha * data[i] + alpha_rev * ema[i-1] return ema def is_uncertain_market(self): """Determine if the market is in an uncertain/range-bound state""" if len(self.features['price']) < 30: return False # Check if EMAs are close to each other (no clear trend) if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: ema9 = self.features['ema_9'][-1] ema21 = self.features['ema_21'][-1] # If EMAs are within 0.2% of each other, market is uncertain if abs(ema9 - ema21) / ema21 < 0.002: return True # Check if price is oscillating without clear direction if len(self.features['price']) >= 10: recent_prices = self.features['price'][-10:] ups = downs = 0 for i in range(1, len(recent_prices)): if recent_prices[i] > recent_prices[i-1]: ups += 1 else: downs += 1 # If there's a mix of ups and downs (neither dominates heavily) return abs(ups - downs) < 3 return False def is_near_support(self): """Determine if the current price is near a support level""" if len(self.features['price']) < 30: return False # Use Bollinger lower band as support if len(self.features['bollinger_lower']) > 0 and len(self.features['price']) > 0: current_price = self.features['price'][-1] lower_band = self.features['bollinger_lower'][-1] # If price is within 0.5% of the lower band if (current_price - lower_band) / current_price < 0.005: return True # Check if we're near recent lows if len(self.features['price']) >= 20: current_price = self.features['price'][-1] min_price = min(self.features['price'][-20:]) # If within 1% of recent lows if (current_price - min_price) / current_price < 0.01: return True return False def is_near_resistance(self): """Determine if the current price is near a resistance level""" if len(self.features['price']) < 30: return False # Use Bollinger upper band as resistance if len(self.features['bollinger_upper']) > 0 and len(self.features['price']) > 0: current_price = self.features['price'][-1] upper_band = self.features['bollinger_upper'][-1] # If price is within 0.5% of the upper band if (upper_band - current_price) / current_price < 0.005: return True # Check if we're near recent highs if len(self.features['price']) >= 20: current_price = self.features['price'][-1] max_price = max(self.features['price'][-20:]) # If within 1% of recent highs if (max_price - current_price) / current_price < 0.01: return True return False def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'): """ Add a candlestick chart and metrics to TensorBoard Parameters: - writer: TensorBoard writer - step: Current step - title: Title for the chart """ try: # Initialize writer if not provided if writer is None: from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() # Log basic metrics writer.add_scalar('Balance', self.balance, step) writer.add_scalar('Total_PnL', self.total_pnl, step) # Log total fees if available if hasattr(self, 'total_fees'): writer.add_scalar('Total_Fees', self.total_fees, step) writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step) # Log position info writer.add_scalar('Position_Size', self.position_size, step) # Log drawdown and win rate writer.add_scalar('Max_Drawdown', self.max_drawdown, step) win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0 writer.add_scalar('Win_Rate', win_rate, step) # Log trade count writer.add_scalar('Trade_Count', len(self.trades), step) # Check if we have enough data for candlestick chart if len(self.data) <= 0: logger.warning("No data available for candlestick chart") return # Create figure for candlestick chart (last 100 data points) start_idx = max(0, self.current_step - 100) end_idx = self.current_step # Get recent trades for visualization (last 10 trades) recent_trades = self.trades[-10:] if self.trades else [] try: fig = create_candlestick_figure( self.data[start_idx:end_idx+1], title=title, trades=recent_trades ) # Add figure to TensorBoard writer.add_figure('Candlestick_Chart', fig, step) # Close figure to free memory plt.close(fig) except Exception as e: logger.error(f"Error creating candlestick chart: {e}") except Exception as e: logger.error(f"Error adding chart to TensorBoard: {e}") # Continue execution even if chart fails def get_realtime_state(self, tick_data): """ Create a state representation optimized for real-time processing. This is a streamlined version of get_state() designed for minimal latency. TODO: Implement optimized state creation from tick data """ # This would be a simplified version of get_state that processes only # the most important features needed for real-time decision making # Example implementation: # realtime_features = { # 'price': tick_data['price'], # 'volume': tick_data['volume'], # 'ema_short': self._calculate_ema(tick_data['price'], 9), # 'ema_long': self._calculate_ema(tick_data['price'], 21), # } # Convert to tensor or numpy array in the required format # return torch.tensor([...], dtype=torch.float32) # Placeholder return np.zeros((self.observation_space.shape[0],), dtype=np.float32) # Ensure GPU usage if available def get_device(): """Get the best available device (CUDA GPU or CPU)""" if torch.cuda.is_available(): device = torch.device("cuda") logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}") # Set up for mixed precision training torch.backends.cudnn.benchmark = True else: device = torch.device("cpu") logger.info("GPU not available, using CPU") return device # Update Agent class to use GPU properly class Agent: def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None): """Initialize Agent with architecture parameters stored as attributes""" self.state_size = state_size self.action_size = action_size self.hidden_size = hidden_size # Store hidden_size as an instance attribute self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute self.attention_heads = attention_heads # Store attention_heads as an instance attribute # Set device self.device = device if device is not None else get_device() # Initialize networks - use LSTMAttentionDQN instead of DQN self.policy_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) # Initialize optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Initialize replay memory self.memory = ReplayMemory(MEMORY_SIZE) # Initialize exploration parameters self.epsilon = EPSILON_START self.epsilon_start = EPSILON_START self.epsilon_end = EPSILON_END self.epsilon_decay = EPSILON_DECAY self.epsilon_min = EPSILON_END # Initialize step counter self.steps_done = 0 # Initialize TensorBoard writer self.writer = None # Initialize GradScaler for mixed precision training self.scaler = torch.amp.GradScaler('cuda') if self.device.type == "cuda" else None # Initialize candle cache for multi-timeframe data self.candle_cache = CandleCache() # Store model name for logging self.model_name = f"LSTM_Attention_DQN_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" logger.info(f"Initialized agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}") logger.info(f"Using device: {self.device}") def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" logger.info(f"Expanding model: {self.state_size} → {new_state_size}, " f"hidden: {self.policy_net.hidden_size} → {new_hidden_size}") # Save old weights old_state_dict = self.policy_net.state_dict() # Create new larger networks new_policy_net = LSTMAttentionDQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) new_target_net = LSTMAttentionDQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) # Transfer weights for common layers new_state_dict = new_policy_net.state_dict() for name, param in old_state_dict.items(): if name in new_state_dict: # If shapes match, copy directly if new_state_dict[name].shape == param.shape: new_state_dict[name] = param # For first layer, copy weights for the original input dimensions elif name == "fc1.weight": new_state_dict[name][:, :self.state_size] = param # For other layers, initialize with a strategy that preserves scale else: logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}") # Load transferred weights new_policy_net.load_state_dict(new_state_dict) new_target_net.load_state_dict(new_state_dict) # Replace networks self.policy_net = new_policy_net self.target_net = new_target_net self.target_net.eval() # Update optimizer self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) # Update state size self.state_size = new_state_size # Print new model size total_params = sum(p.numel() for p in self.policy_net.parameters()) logger.info(f"New model size: {total_params:,} parameters") return True def select_action(self, state, training=True, candle_data=None): """ Select an action using the policy network. Args: state: The current state training: Whether we're in training mode (for epsilon-greedy) candle_data: Dictionary with ['1s'-later], '1m', '1h', '1d' candle data Returns: The selected action """ # ... existing code ... # Add CNN processing if candle data is available cnn_inputs = None if candle_data and all(k in candle_data for k in [ '1m', '1h', '1d']): # Process candle data into tensors # x_1s = self.prepare_candle_tensor(candle_data['1s']) x_1m = self.prepare_candle_tensor(candle_data['1m']) x_1h = self.prepare_candle_tensor(candle_data['1h']) x_1d = self.prepare_candle_tensor(candle_data['1d']) cnn_inputs = (x_1m, x_1h, x_1d) # Use epsilon-greedy strategy during training if training and random.random() < self.epsilon: return random.randrange(self.action_size) with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) if cnn_inputs: q_values = self.policy_net(state_tensor, *cnn_inputs) else: q_values = self.policy_net(state_tensor) return q_values.max(1)[1].item() def prepare_candle_tensor(self, candles, max_candles=300): """Convert candle data to tensors for CNN input""" if not candles: # Return zeros if no candles available return torch.zeros((1, 5, max_candles), device=self.device) # Limit to the most recent candles candles = candles[-max_candles:] # Extract OHLCV data ohlcv = np.array([[c[1], c[2], c[3], c[4], c[5]] for c in candles], dtype=np.float32) # Normalize the data if len(ohlcv) > 0: # Simple min-max normalization per column min_vals = ohlcv.min(axis=0, keepdims=True) max_vals = ohlcv.max(axis=0, keepdims=True) range_vals = max_vals - min_vals range_vals[range_vals == 0] = 1 # Avoid division by zero ohlcv = (ohlcv - min_vals) / range_vals # Pad if needed padded = np.zeros((max_candles, 5), dtype=np.float32) padded[-len(ohlcv):] = ohlcv # Convert to tensor [batch, channels, sequence] tensor = torch.FloatTensor(padded.transpose(1, 0)).unsqueeze(0).to(self.device) return tensor else: return torch.zeros((1, 5, max_candles), device=self.device) def learn(self): """Learn from a batch of experiences with GPU acceleration and CNN features""" if len(self.memory) < BATCH_SIZE: return None try: # Sample a batch of experiences experiences = self.memory.sample(BATCH_SIZE) # Convert experiences to tensors states = torch.FloatTensor([e.state for e in experiences]).to(self.device) actions = torch.LongTensor([e.action for e in experiences]).to(self.device) rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device) next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device) dones = torch.FloatTensor([e.done for e in experiences]).to(self.device) # Use mixed precision for forward/backward passes if on GPU if self.device.type == "cuda" and self.scaler is not None: with torch.cuda.amp.autocast(): # Compute current Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] # Compute target Q values target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass with gradient scaling self.optimizer.zero_grad() self.scaler.scale(loss).backward() # Clip gradients self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) # Update weights self.scaler.step(self.optimizer) self.scaler.update() else: # Standard precision for CPU # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) # Compute next Q values with target network with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping to prevent exploding gradients torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() # Update steps done self.steps_done += 1 # Update target network if self.steps_done % TARGET_UPDATE == 0: self.target_net.load_state_dict(self.policy_net.state_dict()) return loss.item() except Exception as e: logger.error(f"Error during learning: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None def update_epsilon(self, episode): """Update epsilon value based on episode number""" # Calculate epsilon using a linear decay formula epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \ max(0, (self.epsilon_decay - episode)) / self.epsilon_decay # Update self.epsilon with the calculated value self.epsilon = max(self.epsilon_min, epsilon) return self.epsilon def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path="models/trading_agent_best_pnl.pt"): """Save the model using a robust saving approach with multiple fallbacks""" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(path), exist_ok=True) # Call robust save function success = robust_save(self, path) if success: logger.info(f"Model saved successfully to {path}") return True else: logger.error(f"All save attempts failed for path: {path}") return False except Exception as e: logger.error(f"Error in save method: {e}") logger.error(traceback.format_exc()) return False def load(self, path="models/trading_agent_best_pnl.pt"): """Load a trained model with improved error handling for PyTorch 2.6 compatibility""" try: # First try to load with weights_only=False (for models saved with older PyTorch versions) try: logger.info(f"Attempting to load model with weights_only=False: {path}") checkpoint = torch.load(path, map_location=self.device, weights_only=False) logger.info("Model loaded successfully with weights_only=False") except Exception as e1: logger.warning(f"Failed to load with weights_only=False: {e1}") # Try with safe_globals context manager try: logger.info("Attempting to load with safe_globals context manager") import numpy as np from torch.serialization import safe_globals # Add numpy scalar to safe globals with safe_globals(['numpy._core.multiarray.scalar']): checkpoint = torch.load(path, map_location=self.device) logger.info("Model loaded successfully with safe_globals") except Exception as e2: logger.warning(f"Failed to load with safe_globals: {e2}") # Last resort: try with pickle_module=pickle logger.info("Attempting to load with pickle_module") import pickle checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False) logger.info("Model loaded successfully with pickle_module") # Load state dictionaries self.policy_net.load_state_dict(checkpoint['policy_net']) self.target_net.load_state_dict(checkpoint['target_net']) # Try to load optimizer state try: self.optimizer.load_state_dict(checkpoint['optimizer']) except Exception as e: logger.warning(f"Could not load optimizer state: {e}") # Load epsilon if available if 'epsilon' in checkpoint: self.epsilon = checkpoint['epsilon'] # Load architecture parameters if available if 'state_size' in checkpoint: self.state_size = checkpoint['state_size'] if 'action_size' in checkpoint: self.action_size = checkpoint['action_size'] if 'hidden_size' in checkpoint: self.hidden_size = checkpoint['hidden_size'] else: # If hidden_size not in checkpoint, infer from model try: self.hidden_size = self.policy_net.fc1.weight.shape[0] logger.info(f"Inferred hidden_size={self.hidden_size} from model") except: self.hidden_size = 256 # Default value logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}") if 'lstm_layers' in checkpoint: self.lstm_layers = checkpoint['lstm_layers'] else: self.lstm_layers = 2 # Default value if 'attention_heads' in checkpoint: self.attention_heads = checkpoint['attention_heads'] else: self.attention_heads = 4 # Default value logger.info(f"Model loaded successfully from {path}") except Exception as e: logger.error(f"Error loading model: {e}") import traceback logger.error(traceback.format_exc()) raise def add_chart_to_tensorboard(self, env, step): """Add candlestick chart to tensorboard and various metrics""" try: # Initialize writer if it doesn't exist if not hasattr(self, 'writer') or self.writer is None: self.writer = SummaryWriter(log_dir=f'runs/{self.model_name}') # Check if we have enough data if not hasattr(env, 'data') or len(env.data) < 20: logger.warning("Not enough data for chart in TensorBoard") return # Get position value (convert from string if needed) position_value = 0 # Default to flat if hasattr(env, 'position'): if isinstance(env.position, str): # Map string positions to numeric values position_map = {'flat': 0, 'long': 1, 'short': -1} position_value = position_map.get(env.position.lower(), 0) else: position_value = float(env.position) # Log metrics to tensorboard self.writer.add_scalar('Trading/Position', position_value, step) if hasattr(env, 'balance'): self.writer.add_scalar('Trading/Balance', env.balance, step) if hasattr(env, 'total_pnl'): self.writer.add_scalar('Trading/Total_PnL', env.total_pnl, step) if hasattr(env, 'max_drawdown'): self.writer.add_scalar('Trading/Drawdown', env.max_drawdown, step) if hasattr(env, 'win_rate'): self.writer.add_scalar('Trading/Win_Rate', env.win_rate, step) if hasattr(env, 'trade_count'): self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step) # Log trading fees if hasattr(env, 'total_fees'): self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step) # Also log net PnL (after fees) if hasattr(env, 'total_pnl'): self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step) # Add candlestick chart if we have enough data if len(env.data) >= 100: try: # Use the last 100 candles for the chart recent_data = env.data[-100:] # Get recent trades if available recent_trades = None if hasattr(env, 'trades') and len(env.trades) > 0: recent_trades = env.trades[-10:] # Last 10 trades # Create candlestick figure fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}") if fig: # Add to tensorboard self.writer.add_figure('Trading/Chart', fig, step) # Close figure to free memory plt.close(fig) except Exception as e: logger.warning(f"Error creating candlestick chart: {e}") except Exception as e: logger.error(f"Error in add_chart_to_tensorboard: {e}") def select_action_realtime(self, state): """ Select action with minimal latency for real-time trading. Optimized version of select_action for ultra-low latency requirements. TODO: Implement optimized action selection for real-time trading """ # Convert to tensor if needed state_tensor = torch.tensor(state, dtype=torch.float32) # Fast forward pass through the network with torch.no_grad(): q_values = self.policy_net.forward_realtime(state_tensor.unsqueeze(0)) # Get the action with highest Q-value action = q_values.max(1)[1].item() return action def forward_realtime(self, state): """ Optimized forward pass for real-time trading with minimal latency. TODO: Implement streamlined forward pass that prioritizes speed """ # For now, just use the regular forward pass # This could be optimized later with techniques like: # - Using a smaller model for real-time decisions # - Skipping certain layers or calculations # - Using quantized weights or other optimizations return self.forward(state) async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" # Connect to MEXC websocket uri = "wss://stream.mexc.com/ws" async with websockets.connect(uri) as websocket: # Subscribe to kline data subscribe_msg = { "method": "SUBSCRIPTION", "params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"] } await websocket.send(json.dumps(subscribe_msg)) logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines") while True: try: response = await websocket.recv() data = json.loads(response) if 'data' in data: kline = data['data'] candle = { 'timestamp': kline['t'], 'open': float(kline['o']), 'high': float(kline['h']), 'low': float(kline['l']), 'close': float(kline['c']), 'volume': float(kline['v']) } yield candle except Exception as e: logger.error(f"Websocket error: {e}") # Try to reconnect await asyncio.sleep(5) break async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, use_compact_save=False): """ Train the agent in the environment. Args: agent: The agent to train env: The trading environment num_episodes: Number of episodes to train for max_steps_per_episode: Maximum steps per episode use_compact_save: Whether to use compact save (for low disk space) Returns: Training statistics """ # Initialize TensorBoard writer if not already done try: if agent.writer is None: from torch.utils.tensorboard import SummaryWriter agent.writer = SummaryWriter(log_dir=f'runs/{agent.model_name}') writer = agent.writer except Exception as e: logging.error(f"Failed to initialize TensorBoard: {e}") writer = None # Initialize exchange for data fetching try: exchange = await initialize_exchange() logging.info("Initialized exchange for data fetching") except Exception as e: logging.error(f"Failed to initialize exchange: {e}") exchange = None # Initialize statistics tracking stats = { 'episode_rewards': [], 'episode_lengths': [], 'balances': [], 'win_rates': [], 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], 'trade_counts': [], 'loss_values': [], 'fees': [], # Track fees 'net_pnl_after_fees': [] # Track net PnL after fees } # Track best models best_reward = float('-inf') best_pnl = float('-inf') best_net_pnl = float('-inf') # Track best net PnL (after fees) # Make directory for models if it doesn't exist os.makedirs('models', exist_ok=True) # Memory management function def clean_memory(): """Clean up memory to avoid memory leaks""" if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() # Start training loop for episode in range(num_episodes): try: # Clean up memory before starting a new episode clean_memory() # Reset environment state = env.reset() episode_reward = 0 episode_losses = [] # Fetch multi-timeframe data at the start of the episode candle_data = None if exchange: try: candle_data = await fetch_multi_timeframe_data( exchange, "ETH/USDT", agent.candle_cache ) # Update CNN patterns env.update_cnn_patterns(candle_data) logging.info(f"Fetched multi-timeframe data for episode {episode+1}") except Exception as e: logging.error(f"Failed to fetch candle data: {e}") # Track consecutive errors consecutive_errors = 0 max_consecutive_errors = 5 # Episode loop for step in range(max_steps_per_episode): try: # Select action using CNN-enhanced policy action = agent.select_action(state, training=True, candle_data=candle_data) # Take action next_state, reward, done, info = env.step(action) # Store transition in replay memory agent.memory.push(state, action, reward, next_state, done) # Move to the next state state = next_state # Update episode reward episode_reward += reward # Learn from experience if len(agent.memory) > BATCH_SIZE: try: loss = agent.learn() if loss is not None: episode_losses.append(loss) # Log loss to TensorBoard global_step = episode * max_steps_per_episode + step if writer: writer.add_scalar('Loss/step', loss, global_step) # Reset consecutive errors counter on successful learning consecutive_errors = 0 except Exception as e: logging.error(f"Error during learning: {e}") consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") break # Update target network periodically if step % TARGET_UPDATE == 0: agent.update_target_network() # Update price predictions and CNN patterns periodically if step % 50 == 0: try: # Update internal environment predictions if hasattr(env, 'update_price_predictions'): env.update_price_predictions() if hasattr(env, 'identify_optimal_trades'): env.identify_optimal_trades() # Fetch fresh candle data periodically if exchange: try: candle_data = await fetch_multi_timeframe_data( exchange, "ETH/USDT", agent.candle_cache ) # Update CNN patterns with the new candle data env.update_cnn_patterns(candle_data) logging.info(f"Updated multi-timeframe data at step {step}") except Exception as e: logging.error(f"Failed to fetch candle data: {e}") except Exception as e: logging.warning(f"Error updating predictions: {e}") # Clean memory periodically during long episodes if step % 200 == 0 and step > 0: clean_memory() # Add chart to TensorBoard periodically if step % 100 == 0 or (step == max_steps_per_episode - 1) or done: try: global_step = episode * max_steps_per_episode + step if writer: agent.add_chart_to_tensorboard(env, global_step) except Exception as e: logging.warning(f"Error adding chart to TensorBoard: {e}") if done: break except Exception as e: logging.error(f"Error in training step: {e}") consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") break # Calculate statistics from this episode balance = env.balance pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0 fees = env.total_fees if hasattr(env, 'total_fees') else 0 net_pnl = pnl - fees # Calculate net PnL after fees # Get trading statistics trade_analysis = None if hasattr(env, 'analyze_trades'): trade_analysis = env.analyze_trades() win_rate = trade_analysis['win_rate'] if trade_analysis and 'win_rate' in trade_analysis else 0 trade_count = trade_analysis['total_trades'] if trade_analysis and 'total_trades' in trade_analysis else 0 max_drawdown = trade_analysis['max_drawdown'] if trade_analysis and 'max_drawdown' in trade_analysis else 0 # Calculate average loss for this episode avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0 # Log episode metrics to TensorBoard if writer: writer.add_scalar('Reward/episode', episode_reward, episode) writer.add_scalar('Balance/episode', balance, episode) writer.add_scalar('PnL/episode', pnl, episode) writer.add_scalar('NetPnL/episode', net_pnl, episode) writer.add_scalar('Fees/episode', fees, episode) writer.add_scalar('WinRate/episode', win_rate, episode) writer.add_scalar('TradeCount/episode', trade_count, episode) writer.add_scalar('Drawdown/episode', max_drawdown, episode) writer.add_scalar('Loss/episode', avg_loss, episode) writer.add_scalar('Epsilon/episode', agent.epsilon, episode) # Update stats dictionary stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) stats['balances'].append(balance) stats['win_rates'].append(win_rate) stats['episode_pnls'].append(pnl) stats['drawdowns'].append(max_drawdown) stats['trade_counts'].append(trade_count) stats['loss_values'].append(avg_loss) stats['fees'].append(fees) stats['net_pnl_after_fees'].append(net_pnl) # Calculate and update cumulative PnL if len(stats['episode_pnls']) > 0: cumulative_pnl = sum(stats['episode_pnls']) if 'cumulative_pnl' not in stats: stats['cumulative_pnl'] = [] stats['cumulative_pnl'].append(cumulative_pnl) if writer: writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode) writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode) # Save model if this is the best reward or PnL if episode_reward > best_reward: best_reward = episode_reward try: if use_compact_save: success = compact_save(agent, 'models/trading_agent_best_reward.pt') else: success = agent.save('models/trading_agent_best_reward.pt') if success: logging.info(f"New best reward: {best_reward:.2f}") except Exception as e: logging.error(f"Error saving best reward model: {e}") if pnl > best_pnl: best_pnl = pnl try: if use_compact_save: success = compact_save(agent, 'models/trading_agent_best_pnl.pt') else: success = agent.save('models/trading_agent_best_pnl.pt') if success: logging.info(f"New best PnL: ${best_pnl:.2f}") except Exception as e: logging.error(f"Error saving best PnL model: {e}") # Save model if this is the best net PnL (after fees) if net_pnl > best_net_pnl: best_net_pnl = net_pnl try: if use_compact_save: success = compact_save(agent, 'models/trading_agent_best_net_pnl.pt') else: success = agent.save('models/trading_agent_best_net_pnl.pt') if success: logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") except Exception as e: logging.error(f"Error saving best net PnL model: {e}") # Save checkpoint periodically if episode % 10 == 0: try: if use_compact_save: compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt') else: agent.save(f'models/trading_agent_checkpoint_{episode}.pt') except Exception as e: logging.error(f"Error saving checkpoint model: {e}") # Update epsilon agent.update_epsilon(episode) # Log training progress logging.info(f"Episode {episode+1}/{num_episodes} | " + f"Reward: {episode_reward:.2f} | " + f"Balance: ${balance:.2f} | " + f"PnL: ${pnl:.2f} | " + f"Fees: ${fees:.2f} | " + f"Net PnL: ${net_pnl:.2f} | " + f"Win Rate: {win_rate:.2f} | " + f"Trades: {trade_count} | " + f"Loss: {avg_loss:.5f} | " + f"Epsilon: {agent.epsilon:.4f}") except Exception as e: logging.error(f"Error in episode {episode}: {e}") logging.error(traceback.format_exc()) continue # Clean memory before saving final model clean_memory() # Save final model try: if use_compact_save: compact_save(agent, 'models/trading_agent_final.pt') else: agent.save('models/trading_agent_final.pt') except Exception as e: logging.error(f"Error saving final model: {e}") # Save training statistics to file try: import pandas as pd # Make sure all arrays in stats are the same length by padding with NaN max_length = max(len(v) for k, v in stats.items() if isinstance(v, list)) for k, v in stats.items(): if isinstance(v, list) and len(v) < max_length: stats[k] = v + [float('nan')] * (max_length - len(v)) # Create dataframe and save stats_df = pd.DataFrame(stats) stats_df.to_csv('training_stats.csv', index=False) logging.info(f"Training statistics saved to training_stats.csv") except Exception as e: logging.error(f"Failed to save training statistics: {e}") logging.error(traceback.format_exc()) # Close exchange if it's still open if exchange: try: # Check if exchange has the close method (ccxt.async_support) if hasattr(exchange, 'close'): await exchange.close() logging.info("Closed exchange connection") else: logging.info("Exchange doesn't have close method (standard ccxt), skipping close") except Exception as e: logging.error(f"Error closing exchange: {e}") return stats def plot_training_results(stats): """Plot training results and save to file""" try: # Check if we have data to plot if not stats or len(stats.get('episode_rewards', [])) == 0: logger.warning("No training data to plot") return # Create a DataFrame with consistent lengths max_len = max(len(stats.get(key, [])) for key in stats) # Ensure all arrays have the same length by padding with the last value or zeros processed_stats = {} for key, values in stats.items(): if not values: # Skip empty lists continue # Pad arrays to the same length if len(values) < max_len: if len(values) > 0: # Pad with the last value values = values + [values[-1]] * (max_len - len(values)) else: # Pad with zeros values = [0] * max_len processed_stats[key] = values[:max_len] # Trim if longer # Create DataFrame df = pd.DataFrame(processed_stats) # Add episode column df['episode'] = range(1, len(df) + 1) # Create figure with subplots fig, axes = plt.subplots(3, 2, figsize=(15, 15)) # Plot episode rewards if 'episode_rewards' in df.columns: axes[0, 0].plot(df['episode'], df['episode_rewards']) axes[0, 0].set_title('Episode Rewards') axes[0, 0].set_xlabel('Episode') axes[0, 0].set_ylabel('Reward') axes[0, 0].grid(True) # Plot account balance if 'balances' in df.columns: axes[0, 1].plot(df['episode'], df['balances']) axes[0, 1].set_title('Account Balance') axes[0, 1].set_xlabel('Episode') axes[0, 1].set_ylabel('Balance ($)') axes[0, 1].grid(True) # Plot win rate if 'win_rates' in df.columns: axes[1, 0].plot(df['episode'], df['win_rates']) axes[1, 0].set_title('Win Rate') axes[1, 0].set_xlabel('Episode') axes[1, 0].set_ylabel('Win Rate') axes[1, 0].set_ylim([0, 1]) axes[1, 0].grid(True) # Plot episode PnL if 'episode_pnls' in df.columns: axes[1, 1].plot(df['episode'], df['episode_pnls']) axes[1, 1].set_title('Episode PnL') axes[1, 1].set_xlabel('Episode') axes[1, 1].set_ylabel('PnL ($)') axes[1, 1].grid(True) # Plot cumulative PnL if 'cumulative_pnl' in df.columns: axes[2, 0].plot(df['episode'], df['cumulative_pnl']) axes[2, 0].set_title('Cumulative PnL') axes[2, 0].set_xlabel('Episode') axes[2, 0].set_ylabel('Cumulative PnL ($)') axes[2, 0].grid(True) # Plot maximum drawdown if 'drawdowns' in df.columns: axes[2, 1].plot(df['episode'], df['drawdowns']) axes[2, 1].set_title('Maximum Drawdown') axes[2, 1].set_xlabel('Episode') axes[2, 1].set_ylabel('Drawdown') axes[2, 1].grid(True) # Adjust layout plt.tight_layout() # Save figure plt.savefig('training_results.png') logger.info("Training results saved to training_results.png") # Save statistics to CSV df.to_csv('training_stats.csv', index=False) logger.info("Training statistics saved to training_stats.csv") except Exception as e: logger.error(f"Error plotting training results: {e}") logger.error(traceback.format_exc()) def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" total_reward = 0 total_profit = 0 total_trades = 0 winning_trades = 0 for episode in range(num_episodes): state = env.reset() episode_reward = 0 initial_balance = env.balance done = False while not done: # Select action (no exploration) action = agent.select_action(state, training=False) next_state, reward, done, info = env.step(action) state = next_state episode_reward += reward total_reward += episode_reward total_profit += env.balance - initial_balance # Count trades and wins for trade in env.trades: if 'pnl_percent' in trade: total_trades += 1 if trade['pnl_percent'] > 0: winning_trades += 1 # Calculate averages avg_reward = total_reward / num_episodes avg_profit = total_profit / num_episodes win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0 logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, " f"Win Rate={win_rate:.1f}%") return avg_reward, avg_profit, win_rate async def test_training(): """Test the training process with a small number of episodes""" logger.info("Starting training tests...") # Initialize exchange exchange = ccxt.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True, }) try: # Create environment with small initial balance for testing env = TradingEnvironment( exchange=exchange, symbol="ETH/USDT", timeframe="1m", leverage=MAX_LEVERAGE, initial_balance=100, # Small balance for testing demo=True # Always use demo mode for testing ) # Fetch initial data await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000) # Create agent agent = Agent(state_size=STATE_SIZE, action_size=env.action_space) # Run a few test episodes test_episodes = 3 logger.info(f"Running {test_episodes} test episodes...") for episode in range(test_episodes): state = env.reset() episode_reward = 0 done = False step = 0 while not done and step < 100: # Limit steps for testing # Select action action = agent.select_action(state) # Take action next_state, reward, done, info = env.step(action) # Store experience agent.memory.push(state, action, reward, next_state, done) # Learn loss = agent.learn() state = next_state episode_reward += reward step += 1 # Print progress if step % 10 == 0: logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}") logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}") # Test model saving try: agent.save("models/test_model.pt") logger.info("Successfully saved model") except Exception as e: logger.error(f"Error saving model: {e}") logger.info("Training tests completed successfully") return True except Exception as e: logger.error(f"Training test failed: {e}") return False finally: await exchange.close() async def initialize_exchange(): """Initialize the exchange connection""" try: # Try to initialize with async support first try: exchange = ccxt.pro.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True }) logger.info(f"Exchange initialized with async support: {exchange.id}") except (AttributeError, ImportError): # Fall back to standard CCXT exchange = ccxt.mexc({ 'apiKey': MEXC_API_KEY, 'secret': MEXC_SECRET_KEY, 'enableRateLimit': True }) logger.info(f"Exchange initialized with standard CCXT: {exchange.id}") return exchange except Exception as e: logger.error(f"Failed to initialize exchange: {e}") raise async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """Fetch historical OHLCV data from the exchange""" try: logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}") # Use the refactored fetch method data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) if not data: logger.warning("No historical data received") return data except Exception as e: logger.error(f"Failed to fetch historical data: {e}") return [] async def live_trading( symbol="ETH/USDT", timeframe="1m", model_path=None, demo=False, leverage=50, initial_balance=1000, max_position_size=0.1, commission=0.0004, window_size=30, update_interval=60, stop_loss_pct=0.02, take_profit_pct=0.04, max_trades_per_day=10, risk_per_trade=0.02, use_trailing_stop=False, trailing_stop_callback=0.005, use_dynamic_sizing=True, use_volatility_sizing=True, use_multi_timeframe=True, use_sentiment=False, use_limit_orders=False, use_dollar_cost_avg=False, use_grid_trading=False, use_martingale=False, use_anti_martingale=False, use_custom_indicators=True, use_ml_predictions=True, use_ensemble=True, use_reinforcement=True, use_risk_management=True, use_portfolio_management=False, use_position_sizing=True, use_stop_loss=True, use_take_profit=True, use_trailing_stop_loss=False, use_dynamic_stop_loss=True, use_dynamic_take_profit=True, use_dynamic_trailing_stop=False, use_dynamic_position_sizing=True, use_dynamic_leverage=False, use_dynamic_risk_per_trade=True, use_dynamic_max_trades_per_day=False, use_dynamic_update_interval=False, use_dynamic_window_size=False, use_dynamic_commission=False, use_dynamic_timeframe=False, use_dynamic_symbol=False, use_dynamic_model_path=False, use_dynamic_demo=False, use_dynamic_leverage_value=False, use_dynamic_initial_balance=False, use_dynamic_max_position_size=False, use_dynamic_stop_loss_pct=False, use_dynamic_take_profit_pct=False, use_dynamic_risk_per_trade_value=False, use_dynamic_trailing_stop_callback=False, use_dynamic_use_trailing_stop=False, use_dynamic_use_dynamic_sizing=False, use_dynamic_use_volatility_sizing=False, use_dynamic_use_multi_timeframe=False, use_dynamic_use_sentiment=False, use_dynamic_use_limit_orders=False, use_dynamic_use_dollar_cost_avg=False, use_dynamic_use_grid_trading=False, use_dynamic_use_martingale=False, use_dynamic_use_anti_martingale=False, use_dynamic_use_custom_indicators=False, use_dynamic_use_ml_predictions=False, use_dynamic_use_ensemble=False, use_dynamic_use_reinforcement=False, use_dynamic_use_risk_management=False, use_dynamic_use_portfolio_management=False, use_dynamic_use_position_sizing=False, use_dynamic_use_stop_loss=False, use_dynamic_use_take_profit=False, use_dynamic_use_trailing_stop_loss=False, use_dynamic_use_dynamic_stop_loss=False, use_dynamic_use_dynamic_take_profit=False, use_dynamic_use_dynamic_trailing_stop=False, use_dynamic_use_dynamic_position_sizing=False, use_dynamic_use_dynamic_leverage=False, use_dynamic_use_dynamic_risk_per_trade=False, use_dynamic_use_dynamic_max_trades_per_day=False, use_dynamic_use_dynamic_update_interval=False, use_dynamic_use_dynamic_window_size=False, use_dynamic_use_dynamic_commission=False, use_dynamic_use_dynamic_timeframe=False, use_dynamic_use_dynamic_symbol=False, use_dynamic_use_dynamic_model_path=False, use_dynamic_use_dynamic_demo=False, use_dynamic_use_dynamic_leverage_value=False, use_dynamic_use_dynamic_initial_balance=False, use_dynamic_use_dynamic_max_position_size=False, use_dynamic_use_dynamic_stop_loss_pct=False, use_dynamic_use_dynamic_take_profit_pct=False, use_dynamic_use_dynamic_risk_per_trade_value=False, use_dynamic_use_dynamic_trailing_stop_callback=False, ): """ Live trading function that connects to the exchange and trades in real-time. Args: symbol: Trading pair symbol timeframe: Timeframe for trading model_path: Path to the trained model demo: Whether to use demo mode (sandbox) leverage: Leverage to use initial_balance: Initial balance max_position_size: Maximum position size as a percentage of balance commission: Commission rate window_size: Window size for the environment update_interval: Interval to update data in seconds stop_loss_pct: Stop loss percentage take_profit_pct: Take profit percentage max_trades_per_day: Maximum trades per day risk_per_trade: Risk per trade as a percentage of balance use_trailing_stop: Whether to use trailing stop trailing_stop_callback: Trailing stop callback percentage use_dynamic_sizing: Whether to use dynamic position sizing use_volatility_sizing: Whether to use volatility-based position sizing use_multi_timeframe: Whether to use multi-timeframe analysis use_sentiment: Whether to use sentiment analysis use_limit_orders: Whether to use limit orders use_dollar_cost_avg: Whether to use dollar cost averaging use_grid_trading: Whether to use grid trading use_martingale: Whether to use martingale strategy use_anti_martingale: Whether to use anti-martingale strategy use_custom_indicators: Whether to use custom indicators use_ml_predictions: Whether to use ML predictions use_ensemble: Whether to use ensemble methods use_reinforcement: Whether to use reinforcement learning use_risk_management: Whether to use risk management use_portfolio_management: Whether to use portfolio management use_position_sizing: Whether to use position sizing use_stop_loss: Whether to use stop loss use_take_profit: Whether to use take profit use_trailing_stop_loss: Whether to use trailing stop loss use_dynamic_stop_loss: Whether to use dynamic stop loss use_dynamic_take_profit: Whether to use dynamic take profit use_dynamic_trailing_stop: Whether to use dynamic trailing stop use_dynamic_position_sizing: Whether to use dynamic position sizing use_dynamic_leverage: Whether to use dynamic leverage use_dynamic_risk_per_trade: Whether to use dynamic risk per trade use_dynamic_max_trades_per_day: Whether to use dynamic max trades per day use_dynamic_update_interval: Whether to use dynamic update interval use_dynamic_window_size: Whether to use dynamic window size use_dynamic_commission: Whether to use dynamic commission use_dynamic_timeframe: Whether to use dynamic timeframe use_dynamic_symbol: Whether to use dynamic symbol use_dynamic_model_path: Whether to use dynamic model path use_dynamic_demo: Whether to use dynamic demo use_dynamic_leverage_value: Whether to use dynamic leverage value use_dynamic_initial_balance: Whether to use dynamic initial balance use_dynamic_max_position_size: Whether to use dynamic max position size use_dynamic_stop_loss_pct: Whether to use dynamic stop loss percentage use_dynamic_take_profit_pct: Whether to use dynamic take profit percentage use_dynamic_risk_per_trade_value: Whether to use dynamic risk per trade value use_dynamic_trailing_stop_callback: Whether to use dynamic trailing stop callback """ logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe") logger.info(f"Demo mode: {demo}, Leverage: {leverage}x") # Flag to track if we're using mock trading using_mock_trading = False # Initialize exchange try: exchange = await initialize_exchange() # Try to set sandbox mode if demo is True if demo: try: exchange.set_sandbox_mode(demo) logger.info(f"Sandbox mode set to {demo}") except Exception as e: logger.warning(f"Exchange doesn't support sandbox mode: {e}") logger.info("Continuing in mock trading mode instead") using_mock_trading = True # Set leverage if not demo or using_mock_trading: try: await exchange.set_leverage(leverage, symbol) logger.info(f"Leverage set to {leverage}x") except Exception as e: logger.warning(f"Failed to set leverage: {e}") # Initialize environment env = TradingEnvironment( initial_balance=initial_balance, leverage=leverage, window_size=window_size, commission=commission, symbol=symbol, timeframe=timeframe, max_position_size=max_position_size, stop_loss_pct=stop_loss_pct, take_profit_pct=take_profit_pct, max_trades_per_day=max_trades_per_day, risk_per_trade=risk_per_trade, use_trailing_stop=use_trailing_stop, trailing_stop_callback=trailing_stop_callback, use_dynamic_sizing=use_dynamic_sizing, use_volatility_sizing=use_volatility_sizing, use_multi_timeframe=use_multi_timeframe, use_sentiment=use_sentiment, use_limit_orders=use_limit_orders, use_dollar_cost_avg=use_dollar_cost_avg, use_grid_trading=use_grid_trading, use_martingale=use_martingale, use_anti_martingale=use_anti_martingale, use_custom_indicators=use_custom_indicators, use_ml_predictions=use_ml_predictions, use_ensemble=use_ensemble, use_reinforcement=use_reinforcement, use_risk_management=use_risk_management, use_portfolio_management=use_portfolio_management, use_position_sizing=use_position_sizing, use_stop_loss=use_stop_loss, use_take_profit=use_take_profit, use_trailing_stop_loss=use_trailing_stop_loss, use_dynamic_stop_loss=use_dynamic_stop_loss, use_dynamic_take_profit=use_dynamic_take_profit, use_dynamic_trailing_stop=use_dynamic_trailing_stop, use_dynamic_position_sizing=use_dynamic_position_sizing, use_dynamic_leverage=use_dynamic_leverage, use_dynamic_risk_per_trade=use_dynamic_risk_per_trade, use_dynamic_max_trades_per_day=use_dynamic_max_trades_per_day, use_dynamic_update_interval=use_dynamic_update_interval, use_dynamic_window_size=use_dynamic_window_size, use_dynamic_commission=use_dynamic_commission, use_dynamic_timeframe=use_dynamic_timeframe, use_dynamic_symbol=use_dynamic_symbol, use_dynamic_model_path=use_dynamic_model_path, use_dynamic_demo=use_dynamic_demo, use_dynamic_leverage_value=use_dynamic_leverage_value, use_dynamic_initial_balance=use_dynamic_initial_balance, use_dynamic_max_position_size=use_dynamic_max_position_size, use_dynamic_stop_loss_pct=use_dynamic_stop_loss_pct, use_dynamic_take_profit_pct=use_dynamic_take_profit_pct, use_dynamic_risk_per_trade_value=use_dynamic_risk_per_trade_value, use_dynamic_trailing_stop_callback=use_dynamic_trailing_stop_callback, use_dynamic_use_trailing_stop=use_dynamic_use_trailing_stop, use_dynamic_use_dynamic_sizing=use_dynamic_use_dynamic_sizing, use_dynamic_use_volatility_sizing=use_dynamic_use_volatility_sizing, use_dynamic_use_multi_timeframe=use_dynamic_use_multi_timeframe, use_dynamic_use_sentiment=use_dynamic_use_sentiment, use_dynamic_use_limit_orders=use_dynamic_use_limit_orders, use_dynamic_use_dollar_cost_avg=use_dynamic_use_dollar_cost_avg, use_dynamic_use_grid_trading=use_dynamic_use_grid_trading, use_dynamic_use_martingale=use_dynamic_use_martingale, use_dynamic_use_anti_martingale=use_dynamic_use_anti_martingale, use_dynamic_use_custom_indicators=use_dynamic_use_custom_indicators, use_dynamic_use_ml_predictions=use_dynamic_use_ml_predictions, use_dynamic_use_ensemble=use_dynamic_use_ensemble, use_dynamic_use_reinforcement=use_dynamic_use_reinforcement, use_dynamic_use_risk_management=use_dynamic_use_risk_management, use_dynamic_use_portfolio_management=use_dynamic_use_portfolio_management, use_dynamic_use_position_sizing=use_dynamic_use_position_sizing, use_dynamic_use_stop_loss=use_dynamic_use_stop_loss, use_dynamic_use_take_profit=use_dynamic_use_take_profit, use_dynamic_use_trailing_stop_loss=use_dynamic_use_trailing_stop_loss, use_dynamic_use_dynamic_stop_loss=use_dynamic_use_dynamic_stop_loss, use_dynamic_use_dynamic_take_profit=use_dynamic_use_dynamic_take_profit, use_dynamic_use_dynamic_trailing_stop=use_dynamic_use_dynamic_trailing_stop, use_dynamic_use_dynamic_position_sizing=use_dynamic_use_dynamic_position_sizing, use_dynamic_use_dynamic_leverage=use_dynamic_use_dynamic_leverage, use_dynamic_use_dynamic_risk_per_trade=use_dynamic_use_dynamic_risk_per_trade, use_dynamic_use_dynamic_max_trades_per_day=use_dynamic_use_dynamic_max_trades_per_day, use_dynamic_use_dynamic_update_interval=use_dynamic_use_dynamic_update_interval, use_dynamic_use_dynamic_window_size=use_dynamic_use_dynamic_window_size, use_dynamic_use_dynamic_commission=use_dynamic_use_dynamic_commission, use_dynamic_use_dynamic_timeframe=use_dynamic_use_dynamic_timeframe, use_dynamic_use_dynamic_symbol=use_dynamic_use_dynamic_symbol, use_dynamic_use_dynamic_model_path=use_dynamic_use_dynamic_model_path, use_dynamic_use_dynamic_demo=use_dynamic_use_dynamic_demo, use_dynamic_use_dynamic_leverage_value=use_dynamic_use_dynamic_leverage_value, use_dynamic_use_dynamic_initial_balance=use_dynamic_use_dynamic_initial_balance, use_dynamic_use_dynamic_max_position_size=use_dynamic_use_dynamic_max_position_size, use_dynamic_use_dynamic_stop_loss_pct=use_dynamic_use_dynamic_stop_loss_pct, use_dynamic_use_dynamic_take_profit_pct=use_dynamic_use_dynamic_take_profit_pct, use_dynamic_use_dynamic_risk_per_trade_value=use_dynamic_use_dynamic_risk_per_trade_value, use_dynamic_use_dynamic_trailing_stop_callback=use_dynamic_use_dynamic_trailing_stop_callback, ) # Fetch initial data logger.info(f"Fetching initial data for {symbol}") await fetch_and_update_data(exchange, env, symbol, timeframe) # Initialize agent STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64 ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4 agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384) # Load model if provided if model_path: agent.load(model_path) logger.info(f"Model loaded successfully from {model_path}") # Initialize TensorBoard writer agent.writer = SummaryWriter(log_dir=f"runs/live_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # Initialize trading statistics trades = [] total_pnl = 0 win_count = 0 loss_count = 0 # Initialize trading log file log_file = f"live_trading_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" with open(log_file, 'w') as f: f.write("timestamp,action,price,position_size,balance,pnl\n") # Start live trading loop logger.info(f"Starting live trading with {symbol} on {timeframe} timeframe") # Main trading loop step_counter = 0 last_update_time = time.time() while True: # Get current state state = env.get_state() # Select action action = agent.select_action(state, training=False) # Take action next_state, reward, done, info = env.step(action) # Log action and results if info.get('trade_executed', False): trade_data = { 'timestamp': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'action': info['action'], 'price': env.current_price, 'position_size': env.position_size, 'balance': env.balance, 'pnl': env.last_trade_profit } trades.append(trade_data) # Update statistics if env.last_trade_profit > 0: win_count += 1 total_pnl += env.last_trade_profit else: loss_count += 1 # Log trade to file with open(log_file, 'a') as f: f.write(f"{trade_data['timestamp']},{trade_data['action']},{trade_data['price']},{trade_data['position_size']},{trade_data['balance']},{trade_data['pnl']}\n") logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}") # Update TensorBoard metrics if step_counter % 10 == 0: agent.writer.add_scalar('Live/Balance', env.balance, step_counter) agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter) agent.writer.add_scalar('Live/Reward', reward, step_counter) # Check if it's time to update data current_time = time.time() if current_time - last_update_time > update_interval: await fetch_and_update_data(exchange, env, symbol, timeframe) last_update_time = current_time # Print status update win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0 logger.info(f""" Step: {step_counter} Balance: ${env.balance:.2f} Total PnL: ${env.total_pnl:.2f} Win Rate: {win_rate:.2f} Trades: {len(trades)} """) # Move to next state state = next_state step_counter += 1 # Sleep to avoid excessive API calls await asyncio.sleep(1) # Check for manual stop if done: break # Close TensorBoard writer agent.writer.close() # Save final statistics win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0 logger.info(f""" Live Trading Summary: Total Steps: {step_counter} Final Balance: ${env.balance:.2f} Total PnL: ${env.total_pnl:.2f} Win Rate: {win_rate:.2f} Total Trades: {len(trades)} """) # Close exchange connection try: await exchange.close() logger.info("Exchange connection closed") except Exception as e: logger.warning(f"Error closing exchange connection: {e}") except Exception as e: logger.error(f"Error in live trading: {e}") logger.error(traceback.format_exc()) try: await exchange.close() except: pass logger.info("Exchange connection closed") async def get_latest_candle(exchange, symbol): """ Get the latest candle for a symbol. Args: exchange: Exchange instance symbol: Trading pair symbol Returns: Latest candle data or None on failure """ try: # Use the refactored fetch method with limit=1 data = await fetch_ohlcv_data(exchange, symbol, "1m", 1) if data and len(data) > 0: return data[0] else: logger.warning("No candle data received") return None except Exception as e: logger.error(f"Failed to fetch latest candle: {e}") return None async def fetch_ohlcv_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """ Fetch OHLCV data from exchange with error handling and retry logic. Args: exchange: The exchange instance symbol: Trading pair symbol timeframe: Candle timeframe limit: Number of candles to fetch Returns: List of candle data or empty list on failure """ max_retries = 3 retry_delay = 5 for attempt in range(max_retries): try: logging.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt {attempt+1}/{max_retries})") # Check if exchange has fetch_ohlcv method if not hasattr(exchange, 'fetch_ohlcv'): logging.error("Exchange does not support OHLCV data fetching") return [] # Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor try: if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, limit=limit) else: # Run in executor to avoid blocking loop = asyncio.get_event_loop() ohlcv = await loop.run_in_executor( None, lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit) ) except Exception as e: logging.error(f"Failed to fetch OHLCV data: {e}") await asyncio.sleep(retry_delay) continue if not ohlcv or len(ohlcv) == 0: logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})") await asyncio.sleep(retry_delay) continue # Convert to list of lists format data = [] for candle in ohlcv: timestamp, open_price, high, low, close, volume = candle data.append([timestamp, open_price, high, low, close, volume]) logging.info(f"Successfully fetched {len(data)} candles") return data except Exception as e: logging.error(f"Error fetching OHLCV data (attempt {attempt+1}/{max_retries}): {e}") if attempt < max_retries - 1: await asyncio.sleep(retry_delay) logging.error(f"Failed to fetch OHLCV data after {max_retries} attempts") return [] # Add this near the top of the file, after imports def ensure_pytorch_compatibility(): """Ensure compatibility with PyTorch 2.6+ for model loading""" try: import torch from torch.serialization import add_safe_globals import numpy as np # Add numpy scalar to safe globals for PyTorch 2.6+ add_safe_globals(['numpy._core.multiarray.scalar']) logger.info("Added numpy scalar to PyTorch safe globals") except (ImportError, AttributeError) as e: logger.warning(f"Could not configure PyTorch compatibility: {e}") logger.warning("This might cause issues with model loading in PyTorch 2.6+") # Call this function at the start of the main function async def main(): # Ensure PyTorch compatibility ensure_pytorch_compatibility() parser = argparse.ArgumentParser(description='Trading Bot') parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train', help='Operation mode: train, eval, or live') parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training or evaluation') parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode for training') parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true', help='Run in demo mode (paper trading) if true') parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol') parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe (1m, 5m, 15m, 1h, etc.)') parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading') parser.add_argument('--model', type=str, default='models/trading_agent_best_net_pnl.pt', help='Path to model file for evaluation or live trading') parser.add_argument('--compact_save', action='store_true', help='Use compact model saving (for low disk space)') args = parser.parse_args() # Convert string boolean to actual boolean demo_mode = args.demo.lower() == 'true' # Get device (GPU or CPU) device = get_device() exchange = None try: # Initialize exchange exchange = await initialize_exchange() # Create environment with updated parameters env = TradingEnvironment( initial_balance=INITIAL_BALANCE, window_size=30, leverage=args.leverage, exchange_id='mexc', symbol=args.symbol, timeframe=args.timeframe ) if args.mode == 'train': # Fetch initial data for training await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000) # Create agent with consistent parameters # Note: Using STATE_SIZE and action_size=4 for consistency agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device) # Train the agent logger.info(f"Starting training for {args.episodes} episodes...") stats = await train_agent(agent, env, num_episodes=args.episodes, max_steps_per_episode=args.max_steps, use_compact_save=args.compact_save) elif args.mode == 'eval' or args.mode == 'live': # Fetch initial data for the specified symbol and timeframe await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000) # Determine model path model_path = args.model if args.model else "models/trading_agent_best_pnl.pt" if not os.path.exists(model_path): logger.error(f"Model file not found: {model_path}") return # Create agent with default parameters agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device) # Try to load the model try: # Add numpy scalar to safe globals before loading import numpy as np from torch.serialization import add_safe_globals # Add numpy scalar to safe globals add_safe_globals(['numpy._core.multiarray.scalar']) # Load the model agent.load(model_path) logger.info(f"Model loaded successfully from {model_path}") except Exception as e: logger.error(f"Failed to load model: {e}") # Ask user if they want to continue with a new model if args.mode == 'live': confirmation = input("Failed to load model. Continue with a new model? (y/n): ") if confirmation.lower() != 'y': logger.info("Live trading canceled by user") return logger.info("Continuing with a new model") else: logger.info("Continuing evaluation with a new model") if args.mode == 'eval': # Evaluate the agent logger.info("Evaluating agent...") avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes) elif args.mode == 'live': # Start live trading logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe") logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x") await live_trading( agent=agent, env=env, exchange=exchange, symbol=args.symbol, timeframe=args.timeframe, demo=demo_mode, leverage=args.leverage ) except Exception as e: logger.error(f"Error in main function: {e}") import traceback logger.error(traceback.format_exc()) finally: # Clean up exchange connection if exchange: try: if hasattr(exchange, 'close'): await exchange.close() elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'): await exchange.client.close() logger.info("Exchange connection closed") except Exception as e: logger.warning(f"Could not properly close exchange connection: {e}") # Add this function near the top with other utility functions def create_candlestick_figure(data, trades=None, title="Trading Chart"): """Create a candlestick chart with trades marked""" try: if data is None or len(data) < 5: logger.warning("Not enough data for candlestick chart") return None # Convert data to DataFrame if it's not already if not isinstance(data, pd.DataFrame): df = pd.DataFrame(data) else: df = data.copy() # Ensure required columns exist required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] for col in required_columns: if col not in df.columns: logger.warning(f"Missing required column {col} for candlestick chart") return None # Format dates if 'timestamp' in df.columns: if isinstance(df['timestamp'].iloc[0], (int, float)): # Convert timestamp to datetime if it's numeric df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') # Set timestamp as index if it's not already if df.index.name != 'timestamp': df.set_index('timestamp', inplace=True) # Rename columns for mplfinance df_mpf = df.copy() if 'open' in df_mpf.columns: df_mpf = df_mpf.rename(columns={ 'open': 'Open', 'high': 'High', 'low': 'Low', 'close': 'Close', 'volume': 'Volume' }) # Create a simple matplotlib figure instead fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [3, 1]}) # Plot candlesticks manually for i in range(len(df_mpf)): # Get date and prices date = df_mpf.index[i] open_price = df_mpf['Open'].iloc[i] high_price = df_mpf['High'].iloc[i] low_price = df_mpf['Low'].iloc[i] close_price = df_mpf['Close'].iloc[i] # Determine color based on price movement color = 'green' if close_price >= open_price else 'red' # Plot candle body body_height = abs(close_price - open_price) body_bottom = min(close_price, open_price) ax1.bar(date, body_height, bottom=body_bottom, width=0.6, color=color, alpha=0.6) # Plot wick ax1.plot([date, date], [low_price, high_price], color=color, linewidth=1) # Plot volume ax2.bar(df_mpf.index, df_mpf['Volume'], width=0.6, color='blue', alpha=0.5) # Mark trades if available if trades and len(trades) > 0: for trade in trades: if 'timestamp' not in trade or 'type' not in trade or 'price' not in trade: continue # Convert timestamp to datetime if needed if isinstance(trade['timestamp'], (int, float)): trade_time = pd.to_datetime(trade['timestamp'], unit='ms') else: trade_time = trade['timestamp'] # Determine marker color based on trade type marker_color = 'green' if trade['type'].lower() == 'buy' else 'red' # Add marker at trade price ax1.scatter(trade_time, trade['price'], color=marker_color, marker='^' if trade['type'].lower() == 'buy' else 'v', s=100, zorder=5) # Set title and labels ax1.set_title(title) ax1.set_ylabel('Price') ax2.set_ylabel('Volume') ax1.grid(True) ax2.grid(True) # Format x-axis plt.setp(ax1.get_xticklabels(), visible=False) # Adjust layout plt.tight_layout() return fig except Exception as e: logger.error(f"Error creating candlestick figure: {e}") return None class CandlePatternCNN(nn.Module): """Convolutional neural network for detecting candlestick patterns""" def __init__(self, input_channels=5, feature_dimension=512): super(CandlePatternCNN, self).__init__() self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.relu2 = nn.ReLU() self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.relu3 = nn.ReLU() self.pool3 = nn.MaxPool2d(kernel_size=2) # Projection layers self.fc1 = nn.Linear(128 * 4 * 4, 1024) self.relu4 = nn.ReLU() self.fc2 = nn.Linear(1024, feature_dimension) # Initialize intermediate features as empty tensors, not as a dict # This makes the model TorchScript compatible self.feature_1m = torch.zeros(1, feature_dimension) self.feature_1h = torch.zeros(1, feature_dimension) self.feature_1d = torch.zeros(1, feature_dimension) def forward(self, x_1m, x_1h, x_1d): # Process 1m data feat_1m = self.process_timeframe(x_1m) # Process 1h data feat_1h = self.process_timeframe(x_1h) # Process 1d data feat_1d = self.process_timeframe(x_1d) # Store features as attributes instead of in a dictionary self.feature_1m = feat_1m self.feature_1h = feat_1h self.feature_1d = feat_1d # Concatenate features from different timeframes combined_features = torch.cat([feat_1m, feat_1h, feat_1d], dim=1) return combined_features def process_timeframe(self, x): """Process a single timeframe batch of data""" # Ensure proper shape for input, handle both batched and single inputs if len(x.shape) == 3: # Single input, shape: [channels, height, width] x = x.unsqueeze(0) # Add batch dimension x = self.pool1(self.relu1(self.conv1(x))) x = self.pool2(self.relu2(self.conv2(x))) x = self.pool3(self.relu3(self.conv3(x))) # Flatten the spatial dimensions for the fully connected layer x = x.view(x.size(0), -1) x = self.relu4(self.fc1(x)) x = self.fc2(x) return x def get_features(self): """Return features for each timeframe""" # Use properties instead of dict for TorchScript compatibility return self.feature_1m, self.feature_1h, self.feature_1d # Add candle cache system class CandleCache: """ Cache system for candles of different timeframes. Reduces API calls by storing and updating candle data. """ def __init__(self): self.candles = { '1m': [], '1h': [], '1d': [] } self.last_updated = { '1m': None, '1h': None, '1d': None } # Add ticks channel for real-time data (WebSocket) self.ticks = [] self.last_tick_time = None def add_candles(self, timeframe, new_candles): """Add new candles to the cache""" if not self.candles[timeframe]: self.candles[timeframe] = new_candles else: # Find the last timestamp in our current cache last_timestamp = self.candles[timeframe][-1][0] # Add only candles newer than our last cached one for candle in new_candles: if candle[0] > last_timestamp: self.candles[timeframe].append(candle) self.last_updated[timeframe] = datetime.datetime.now() def add_tick(self, tick_data): """Add a new tick to the ticks buffer""" self.ticks.append(tick_data) self.last_tick_time = datetime.datetime.now() # Keep only the most recent 1000 ticks to prevent memory issues if len(self.ticks) > 1000: self.ticks = self.ticks[-1000:] def get_ticks(self, limit=None): """Get the most recent ticks from the buffer""" if not self.ticks: return [] if limit and limit > 0: return self.ticks[-limit:] return self.ticks def get_candles(self, timeframe, limit=300): """Get the most recent candles for a timeframe""" if not self.candles[timeframe]: return [] return self.candles[timeframe][-limit:] def needs_update(self, timeframe, max_age_seconds): """Check if the cache needs to be updated""" if not self.last_updated[timeframe]: return True age = (datetime.datetime.now() - self.last_updated[timeframe]).total_seconds() return age > max_age_seconds async def fetch_multi_timeframe_data(exchange, symbol, candle_cache): """Fetch candle data for multiple timeframes, using cache when possible""" update_intervals = { '1m': 60, # Update every 1 minute '1h': 3600, # Update every 1 hour '1d': 86400 # Update every 1 day } # TODO: For 1s/tick timeframes, we'll implement the exchange's WebSocket API # for real-time data streaming in the future. This will enable ultra-low latency # trading signals with minimal delay between market data reception and action execution. # A WebSocket implementation is already prepared in the RealTimeDataStream class. limits = { '1m': 1000, '1h': 500, '1d': 300 } for timeframe, interval in update_intervals.items(): if candle_cache.needs_update(timeframe, interval): try: logging.info(f"Fetching {timeframe} candle data for {symbol}") candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limits[timeframe]) candle_cache.add_candles(timeframe, candles) logging.info(f"Fetched {len(candles)} {timeframe} candles") except Exception as e: logging.error(f"Error fetching {timeframe} candle data: {e}") return { '1m': candle_cache.get_candles('1m'), '1h': candle_cache.get_candles('1h'), '1d': candle_cache.get_candles('1d') } # Modify the LSTMAttentionDQN class to incorporate the CNN features class LSTMAttentionDQN(nn.Module): def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(LSTMAttentionDQN, self).__init__() self.state_size = state_size self.action_size = action_size self.hidden_size = hidden_size self.lstm_layers = lstm_layers self.attention_heads = attention_heads # LSTM layer self.lstm = nn.LSTM( input_size=state_size, hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2 if lstm_layers > 1 else 0 ) # Multi-head self-attention self.attention = nn.MultiheadAttention( embed_dim=hidden_size, num_heads=attention_heads, dropout=0.1 ) # Value stream self.value_stream = nn.Sequential( nn.Linear(hidden_size, 128), nn.ReLU(), nn.Linear(128, 1) ) # Advantage stream self.advantage_stream = nn.Sequential( nn.Linear(hidden_size, 128), nn.ReLU(), nn.Linear(128, action_size) ) # Fusion for multi-timeframe data self.cnn_fusion = nn.Sequential( nn.Linear(512 * 3, 1024), # 512 features from each of the 3 timeframes nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, hidden_size) ) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.LSTM): for name, param in module.named_parameters(): if 'weight' in name: nn.init.xavier_uniform_(param) elif 'bias' in name: nn.init.constant_(param, 0) def forward(self, state, x_1m=None, x_1h=None, x_1d=None): """ Forward pass handling different input shapes and optional CNN features Args: state: Primary state vector (batch_size, sequence_length, state_size) x_1m, x_1h, x_1d: Optional CNN features from different timeframes Returns: Q-values for each action """ batch_size = state.size(0) # Handle CNN features if provided if x_1m is not None and x_1h is not None and x_1d is not None: # Ensure all CNN features have batch dimension if len(x_1m.shape) == 2: x_1m = x_1m.unsqueeze(0) if len(x_1h.shape) == 2: x_1h = x_1h.unsqueeze(0) if len(x_1d.shape) == 2: x_1d = x_1d.unsqueeze(0) # Ensure batch dimensions match if x_1m.size(0) != batch_size: x_1m = x_1m.expand(batch_size, -1, -1) if x_1m.size(0) == 1 else x_1m[:batch_size] if x_1h.size(0) != batch_size: x_1h = x_1h.expand(batch_size, -1, -1) if x_1h.size(0) == 1 else x_1h[:batch_size] if x_1d.size(0) != batch_size: x_1d = x_1d.expand(batch_size, -1, -1) if x_1d.size(0) == 1 else x_1d[:batch_size] # Check dimensions before concatenation if x_1m.dim() == 3 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: # Already in correct format [batch, features] cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1) elif x_1m.dim() == 2 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: # Dimensions correct but missing batch dimension cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1).unsqueeze(0) else: # Reshape to ensure correct dimensions x_1m_flat = x_1m.reshape(batch_size, -1) x_1h_flat = x_1h.reshape(batch_size, -1) x_1d_flat = x_1d.reshape(batch_size, -1) # Handle variable dimensions more gracefully needed_features = 512 if x_1m_flat.size(1) < needed_features: x_1m_flat = F.pad(x_1m_flat, (0, needed_features - x_1m_flat.size(1))) else: x_1m_flat = x_1m_flat[:, :needed_features] if x_1h_flat.size(1) < needed_features: x_1h_flat = F.pad(x_1h_flat, (0, needed_features - x_1h_flat.size(1))) else: x_1h_flat = x_1h_flat[:, :needed_features] if x_1d_flat.size(1) < needed_features: x_1d_flat = F.pad(x_1d_flat, (0, needed_features - x_1d_flat.size(1))) else: x_1d_flat = x_1d_flat[:, :needed_features] # Concatenate cnn_combined = torch.cat([x_1m_flat, x_1h_flat, x_1d_flat], dim=1) # Use CNN fusion network to reduce dimension cnn_features = self.cnn_fusion(cnn_combined) # Reshape to match LSTM input shape cnn_features = cnn_features.view(batch_size, 1, self.hidden_size) # Combine with state input by concatenating along sequence dimension if state.dim() < 3: # If state is 2D [batch, features], reshape to 3D [batch, 1, features] state = state.unsqueeze(1) # Ensure state has proper dimensions if state.size(2) != self.state_size: # If state dimension doesn't match, reshape or pad if state.size(2) > self.state_size: state = state[:, :, :self.state_size] else: state = F.pad(state, (0, self.state_size - state.size(2))) # Concatenate along sequence dimension combined_input = torch.cat([state, cnn_features], dim=1) else: # Use only state input if CNN features not provided combined_input = state if combined_input.dim() < 3: # If state is 2D [batch, features], reshape to 3D [batch, 1, features] combined_input = combined_input.unsqueeze(1) # Ensure state has proper dimensions if combined_input.size(2) != self.state_size: # If state dimension doesn't match, reshape or pad if combined_input.size(2) > self.state_size: combined_input = combined_input[:, :, :self.state_size] else: combined_input = F.pad(combined_input, (0, self.state_size - combined_input.size(2))) # Pass through LSTM lstm_out, _ = self.lstm(combined_input) # Apply self-attention to LSTM output # Transform to shape required by MultiheadAttention (seq_len, batch, hidden) attn_input = lstm_out.transpose(0, 1) attn_output, _ = self.attention(attn_input, attn_input, attn_input) # Transform back to (batch, seq_len, hidden) attn_output = attn_output.transpose(0, 1) # Use last output after attention attn_out = attn_output[:, -1] # Value and advantage streams (dueling architecture) value = self.value_stream(attn_out) advantage = self.advantage_stream(attn_out) # Combine value and advantage for Q-values q_values = value + advantage - advantage.mean(dim=1, keepdim=True) return q_values def forward_realtime(self, x): """Simplified forward pass for realtime inference""" # Adapt x to the right format if needed if isinstance(x, np.ndarray): x = torch.FloatTensor(x) # Add batch dimension if not present if x.dim() == 1: x = x.unsqueeze(0) # Add sequence dimension if not present if x.dim() == 2: x = x.unsqueeze(1) # Basic forward pass lstm_out, _ = self.lstm(x) # Apply attention attn_input = lstm_out.transpose(0, 1) attn_output, _ = self.attention(attn_input, attn_input, attn_input) attn_output = attn_output.transpose(0, 1) # Get last output after attention features = attn_output[:, -1] # Value and advantage streams value = self.value_stream(features) advantage = self.advantage_stream(features) # Combine for Q-values q_values = value + advantage - advantage.mean(dim=1, keepdim=True) return q_values # Add this class after the CandleCache class class RealTimeDataStream: """ Class for handling WebSocket API connections for ultra-low latency trading signals. Provides real-time data streaming at 1-second intervals or faster for immediate trading decisions. """ def __init__(self, exchange, symbol, callback_fn=None): """ Initialize the real-time data stream with WebSocket connection Args: exchange: The exchange API client symbol: Trading pair symbol (e.g. 'ETH/USDT') callback_fn: Function to call when new data is received """ self.exchange = exchange self.symbol = symbol self.callback_fn = callback_fn self.websocket = None self.connected = False self.last_tick_time = None self.tick_buffer = [] self.latency_stats = [] self.logger = logging.getLogger(__name__) # Statistics for monitoring performance self.total_ticks = 0 self.avg_latency_ms = 0 self.max_latency_ms = 0 # Candle cache for storing processed data self.candle_cache = CandleCache() async def connect(self): """Connect to the exchange WebSocket API""" # TODO: Implement actual WebSocket connection logic self.logger.info(f"Connecting to WebSocket for {self.symbol}...") try: # This will be replaced with actual WebSocket connection code self.websocket = None # Placeholder self.connected = True self.logger.info(f"Connected to WebSocket for {self.symbol}") return True except Exception as e: self.logger.error(f"WebSocket connection error: {e}") return False async def subscribe(self): """Subscribe to relevant data channels""" # TODO: Implement actual WebSocket subscription logic self.logger.info(f"Subscribing to {self.symbol} ticks...") try: # This will be replaced with actual subscription code return True except Exception as e: self.logger.error(f"WebSocket subscription error: {e}") return False async def process_message(self, message): """ Process incoming WebSocket message Args: message: The raw WebSocket message Returns: Processed tick data """ # TODO: Implement actual WebSocket message processing logic try: # Track tick receipt time for latency calculations receive_time = time.time() * 1000 # milliseconds # This is a placeholder - actual implementation will parse the message # Example tick data structure (will vary by exchange): tick_data = { 'timestamp': receive_time, 'price': 0.0, # Will be replaced with actual price 'volume': 0.0, # Will be replaced with actual volume 'side': 'buy', # or 'sell' 'exchange_time': 0, # Will be replaced with exchange timestamp 'latency_ms': 0 # Will be calculated } # Calculate latency (difference between our receive time and exchange time) if 'exchange_time' in tick_data and tick_data['exchange_time'] > 0: latency = receive_time - tick_data['exchange_time'] tick_data['latency_ms'] = latency # Update latency statistics self.latency_stats.append(latency) if len(self.latency_stats) > 1000: self.latency_stats = self.latency_stats[-1000:] self.total_ticks += 1 self.avg_latency_ms = sum(self.latency_stats) / len(self.latency_stats) self.max_latency_ms = max(self.max_latency_ms, latency) # Store tick in buffer self.tick_buffer.append(tick_data) self.candle_cache.add_tick(tick_data) self.last_tick_time = datetime.datetime.now() # Keep buffer size reasonable if len(self.tick_buffer) > 1000: self.tick_buffer = self.tick_buffer[-1000:] # Call callback function if provided if self.callback_fn: await self.callback_fn(tick_data) return tick_data except Exception as e: self.logger.error(f"Error processing WebSocket message: {e}") return None def prepare_nn_input(self, model=None, state=None): """ Prepare network inputs from tick data for real-time inference Args: model: The neural network model state: Current state representation Returns: Prepared tensors for model input """ # Get the most recent ticks ticks = self.candle_cache.get_ticks(limit=300) if not ticks or len(ticks) < 10: # Not enough ticks for meaningful processing return None try: # Extract price and volume data from ticks prices = np.array([t['price'] for t in ticks if 'price' in t]) volumes = np.array([t['volume'] for t in ticks if 'volume' in t]) if len(prices) < 10: return None # Normalize data min_price, max_price = prices.min(), prices.max() price_range = max_price - min_price if price_range == 0: price_range = 1 normalized_prices = (prices - min_price) / price_range # Create tick tensor - this is flexible-length data # Format as sequence for time-series analysis tick_data = torch.FloatTensor(normalized_prices).unsqueeze(0).unsqueeze(0) return { 'state': state, 'ticks': tick_data } except Exception as e: self.logger.error(f"Error preparing neural network input: {e}") return None def get_latency_stats(self): """Get statistics about WebSocket connection latency""" return { 'total_ticks': self.total_ticks, 'avg_latency_ms': self.avg_latency_ms, 'max_latency_ms': self.max_latency_ms, 'last_update': self.last_tick_time.isoformat() if self.last_tick_time else None } async def close(self): """Close the WebSocket connection""" if self.connected and self.websocket: try: # This will be replaced with actual close logic self.connected = False self.logger.info(f"Closed WebSocket connection for {self.symbol}") return True except Exception as e: self.logger.error(f"Error closing WebSocket connection: {e}") return False class BacktestCandles(CandleCache): """ Special cache for backtesting that retrieves historical data from specific time periods without contaminating the main cache. Used for running simulations "as if" we were at a different point in time. """ def __init__(self, since_timestamp=None, until_timestamp=None): """ Initialize backtesting candle cache. Args: since_timestamp: Start timestamp for backtesting (milliseconds) until_timestamp: End timestamp for backtesting (milliseconds) """ super().__init__() # Since and until timestamps for backtesting self.since_timestamp = since_timestamp self.until_timestamp = until_timestamp # Flag to indicate this is a backtesting cache self.is_backtesting = True # Optional name for backtesting period (e.g., "Day 1 - 24h ago") self.period_name = None async def fetch_historical_timeframe(self, exchange, symbol, timeframe, limit=1000): """ Fetch historical data for a specific timeframe and time period. Args: exchange: The exchange instance symbol: Trading pair symbol timeframe: Candle timeframe limit: Number of candles to fetch Returns: Dictionary with candle data for the timeframe """ try: logging.info(f"Fetching historical {timeframe} candles for {symbol} " + f"(since: {self.format_timestamp(self.since_timestamp) if self.since_timestamp else 'None'}, " + f"until: {self.format_timestamp(self.until_timestamp) if self.until_timestamp else 'None'})") candles = await self.fetch_ohlcv_with_timerange(exchange, symbol, timeframe, limit, self.since_timestamp, self.until_timestamp) if candles: # Store in the appropriate timeframe self.candles[timeframe] = candles self.last_updated[timeframe] = datetime.datetime.now() logging.info(f"Fetched {len(candles)} historical {timeframe} candles for backtesting") else: logging.warning(f"No historical {timeframe} candles found for the specified time period") return candles except Exception as e: logging.error(f"Error fetching historical {timeframe} data: {e}") return [] async def fetch_all_timeframes(self, exchange, symbol): """ Fetch historical data for all timeframes. Args: exchange: The exchange instance symbol: Trading pair symbol Returns: Dictionary with candle data for all timeframes """ # Define limits for each timeframe limits = { '1m': 1000, '1h': 500, '1d': 300 } # Fetch data for each timeframe for timeframe, limit in limits.items(): await self.fetch_historical_timeframe(exchange, symbol, timeframe, limit) # Return the candles dictionary return { '1m': self.get_candles('1m'), '1h': self.get_candles('1h'), '1d': self.get_candles('1d') } async def fetch_ohlcv_with_timerange(self, exchange, symbol, timeframe, limit, since=None, until=None): """ Fetch OHLCV data within a specific time range. Args: exchange: The exchange instance symbol: Trading pair symbol timeframe: Candle timeframe limit: Number of candles to fetch since: Start timestamp (milliseconds) until: End timestamp (milliseconds) Returns: List of candle data """ max_retries = 3 retry_delay = 5 for attempt in range(max_retries): try: logging.info(f"Fetching {limit} {timeframe} candles for {symbol} " + f"(since: {self.format_timestamp(since) if since else 'None'}, " + f"until: {self.format_timestamp(until) if until else 'None'}) " + f"(attempt {attempt+1}/{max_retries})") # Check if exchange has fetch_ohlcv method if not hasattr(exchange, 'fetch_ohlcv'): logging.error("Exchange does not support OHLCV data fetching") return [] # Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor try: if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, since=since, limit=limit) else: # Run in executor to avoid blocking loop = asyncio.get_event_loop() ohlcv = await loop.run_in_executor( None, lambda: exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) ) except Exception as e: logging.error(f"Failed to fetch OHLCV data: {e}") await asyncio.sleep(retry_delay) continue if not ohlcv or len(ohlcv) == 0: logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})") await asyncio.sleep(retry_delay) continue # Filter candles if until timestamp is provided if until is not None: ohlcv = [candle for candle in ohlcv if candle[0] <= until] # Convert to list of lists format data = [] for candle in ohlcv: timestamp, open_price, high, low, close, volume = candle data.append([timestamp, open_price, high, low, close, volume]) logging.info(f"Successfully fetched {len(data)} historical candles") return data except Exception as e: logging.error(f"Error fetching historical OHLCV data (attempt {attempt+1}/{max_retries}): {e}") if attempt < max_retries - 1: await asyncio.sleep(retry_delay) logging.error(f"Failed to fetch historical OHLCV data after {max_retries} attempts") return [] def format_timestamp(self, timestamp): """Format a timestamp for readable logging""" if timestamp is None: return "None" try: dt = datetime.datetime.fromtimestamp(timestamp / 1000.0) return dt.strftime('%Y-%m-%d %H:%M:%S') except: return str(timestamp) async def train_with_backtesting(agent, env, symbol="ETH/USDT", since_timestamp=None, until_timestamp=None, num_episodes=10, max_steps_per_episode=1000, period_name=None): """ Train agent with backtesting on historical data. Args: agent: The agent to train env: Trading environment symbol: Trading pair symbol since_timestamp: Start timestamp for backtesting until_timestamp: End timestamp for backtesting num_episodes: Number of episodes to train max_steps_per_episode: Maximum steps per episode period_name: Name of the backtest period Returns: Training statistics dictionary """ # Create a backtesting candle cache backtest_cache = BacktestCandles(since_timestamp, until_timestamp) if period_name: backtest_cache.period_name = period_name logging.info(f"Starting backtesting for period: {period_name}") # Initialize exchange for data fetching exchange = None try: exchange = await initialize_exchange() logging.info("Initialized exchange for backtesting") except Exception as e: logging.error(f"Failed to initialize exchange: {e}") return None # Initialize statistics tracking stats = { 'period': period_name, 'since_timestamp': since_timestamp, 'until_timestamp': until_timestamp, 'episode_rewards': [], 'episode_lengths': [], 'balances': [], 'win_rates': [], 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], 'trade_counts': [], 'loss_values': [], 'fees': [], 'net_pnl_after_fees': [] } # Memory management function def clean_memory(): """Clean up memory to avoid memory leaks""" if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() # Fetch historical data for all timeframes try: clean_memory() # Clean memory before fetching data candle_data = await backtest_cache.fetch_all_timeframes(exchange, symbol) if not candle_data or not candle_data['1m']: logging.error(f"No historical data available for backtesting period: {period_name}") try: await exchange.close() except Exception as e: logging.error(f"Error closing exchange: {e}") return None logging.info(f"Fetched historical data for backtesting: {len(candle_data['1m'])} minute candles") except Exception as e: logging.error(f"Failed to fetch historical data for backtesting: {e}") try: await exchange.close() except Exception as exchange_err: logging.error(f"Error closing exchange: {exchange_err}") return None # Track best models best_reward = float('-inf') best_pnl = float('-inf') best_net_pnl = float('-inf') # Make directory for backtesting models if it doesn't exist os.makedirs('models/backtest', exist_ok=True) # Start backtesting training loop for episode in range(num_episodes): try: # Clean memory before starting a new episode clean_memory() # Reset environment state = env.reset() episode_reward = 0 episode_losses = [] # Update CNN patterns with historical data env.update_cnn_patterns(candle_data) # Track consecutive errors for circuit breaker consecutive_errors = 0 max_consecutive_errors = 5 # Episode loop for step in range(max_steps_per_episode): try: # Select action using CNN-enhanced policy action = agent.select_action(state, training=True, candle_data=candle_data) # Take action next_state, reward, done, info = env.step(action) # Store transition in replay memory agent.memory.push(state, action, reward, next_state, done) # Move to the next state state = next_state # Update episode reward episode_reward += reward # Learn from experience if len(agent.memory) > BATCH_SIZE: try: loss = agent.learn() if loss is not None: episode_losses.append(loss) # Reset consecutive errors counter on successful learning consecutive_errors = 0 except Exception as e: logging.error(f"Error during learning: {e}") consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") break # Update target network periodically if step % TARGET_UPDATE == 0: agent.update_target_network() # Clean memory periodically during long episodes if step % 200 == 0 and step > 0: clean_memory() # End episode if done if done: break except Exception as e: logging.error(f"Error in training step: {e}") consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") break # Calculate statistics mean_loss = np.mean(episode_losses) if episode_losses else 0 balance = env.balance pnl = balance - env.initial_balance fees = env.total_fees net_pnl = pnl - fees win_rate = env.win_rate if hasattr(env, 'win_rate') else 0 trade_count = env.trade_count if hasattr(env, 'trade_count') else 0 # Update epsilon for exploration epsilon = agent.update_epsilon(episode) # Update statistics stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) stats['balances'].append(balance) stats['win_rates'].append(win_rate) stats['episode_pnls'].append(pnl) stats['drawdowns'].append(env.max_drawdown) stats['trade_counts'].append(trade_count) stats['loss_values'].append(mean_loss) stats['fees'].append(fees) stats['net_pnl_after_fees'].append(net_pnl) # Calculate and update cumulative PnL if len(stats['episode_pnls']) > 0: cumulative_pnl = sum(stats['episode_pnls']) if 'cumulative_pnl' not in stats: stats['cumulative_pnl'] = [] stats['cumulative_pnl'].append(cumulative_pnl) if writer: writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode) writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode) # Save model if this is the best reward or PnL if episode_reward > best_reward: best_reward = episode_reward model_path = f"models/backtest/{period_name}_best_reward.pt" if period_name else "models/backtest/best_reward.pt" try: agent.save(model_path) logging.info(f"New best reward: {best_reward:.2f}") except Exception as e: logging.error(f"Error saving best reward model: {e}") logging.info(f"New best reward: {best_reward:.2f} (model not saved)") if pnl > best_pnl: best_pnl = pnl model_path = f"models/backtest/{period_name}_best_pnl.pt" if period_name else "models/backtest/best_pnl.pt" try: agent.save(model_path) logging.info(f"New best PnL: ${best_pnl:.2f}") except Exception as e: logging.error(f"Error saving best PnL model: {e}") logging.info(f"New best PnL: ${best_pnl:.2f} (model not saved)") # Save model if this is the best net PnL (after fees) if net_pnl > best_net_pnl: best_net_pnl = net_pnl model_path = f"models/backtest/{period_name}_best_net_pnl.pt" if period_name else "models/backtest/best_net_pnl.pt" try: agent.save(model_path) logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") except Exception as e: logging.error(f"Error saving best net PnL model: {e}") logging.info(f"New best Net PnL: ${best_net_pnl:.2f} (model not saved)") # Save checkpoint periodically if episode % 10 == 0: try: if use_compact_save: compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt') else: agent.save(f'models/trading_agent_checkpoint_{episode}.pt') except Exception as e: logging.error(f"Error saving checkpoint model: {e}") # Update epsilon agent.update_epsilon(episode) # Log training progress logging.info(f"Episode {episode+1}/{num_episodes} | " + f"Reward: {episode_reward:.2f} | " + f"Balance: ${balance:.2f} | " + f"PnL: ${pnl:.2f} | " + f"Fees: ${fees:.2f} | " + f"Net PnL: ${net_pnl:.2f} | " + f"Win Rate: {win_rate:.2f} | " + f"Trades: {trade_count} | " + f"Loss: {mean_loss:.5f} | " + f"Epsilon: {agent.epsilon:.4f}") except Exception as e: logging.error(f"Error in episode {episode}: {e}") logging.error(traceback.format_exc()) continue # Clean memory before saving final model clean_memory() # Save final model if period_name: try: agent.save(f"models/backtest/{period_name}_final.pt") logging.info(f"Saved final model for period: {period_name}") except Exception as e: logging.error(f"Error saving final model: {e}") # Save backtesting statistics stats_file = f"backtest_stats_{period_name}.csv" if period_name else "backtest_stats.csv" try: with open(stats_file, 'w', newline='') as f: writer = csv.writer(f) writer.writerow(['Episode', 'Reward', 'Balance', 'PnL', 'Fees', 'Net PnL', 'Win Rate', 'Trades', 'Loss']) for i in range(len(stats['episode_rewards'])): writer.writerow([ i+1, stats['episode_rewards'][i], stats['balances'][i], stats['episode_pnls'][i], stats['fees'][i], stats['net_pnl_after_fees'][i], stats['win_rates'][i], stats['trade_counts'][i], stats['loss_values'][i] ]) logging.info(f"Backtesting statistics saved to {stats_file}") except Exception as e: logging.error(f"Error saving backtesting statistics: {e}") # Close exchange connection if exchange: try: await exchange.close() logging.info("Exchange connection closed successfully") except AttributeError: # Some exchanges don't have a close method logging.info("Exchange doesn't have a close method, skipping") except Exception as e: logging.error(f"Error closing exchange connection: {e}") return stats # Implement a robust save function to handle PyTorch serialization errors def robust_save(model, path): """ Save a model with multiple fallback approaches to ensure file is saved even in low disk space conditions. """ logger.info(f"Saving model to {path}.backup (attempt 1)") backup_path = f"{path}.backup" # Attempt 1: Regular save to backup file try: checkpoint = { 'policy_net': model.policy_net.state_dict(), 'target_net': model.target_net.state_dict(), 'optimizer': model.optimizer.state_dict(), 'epsilon': model.epsilon } torch.save(checkpoint, backup_path) logger.info(f"Successfully saved to {backup_path}") # If successful, copy to final path try: shutil.copy2(backup_path, path) logger.info(f"Copied backup to {path}") logger.info(f"Model saved successfully to {path}") return True except Exception as e: logger.warning(f"Failed to copy backup to main file: {str(e)}") logger.info(f"Using backup file as the main save") return True except Exception as e: logger.warning(f"First save attempt failed: {str(e)}") # Attempt 2: Try with older pickle protocol logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)") try: checkpoint = { 'policy_net': model.policy_net.state_dict(), 'target_net': model.target_net.state_dict(), 'optimizer': model.optimizer.state_dict(), 'epsilon': model.epsilon } torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) logger.info(f"Successfully saved to {path} with protocol 2") return True except Exception as e: logger.warning(f"Second save attempt failed: {str(e)}") # Attempt 3: Try without optimizer logger.info(f"Saving model to {path} (attempt 3 - without optimizer)") try: checkpoint = { 'policy_net': model.policy_net.state_dict(), 'target_net': model.target_net.state_dict(), 'epsilon': model.epsilon } torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) logger.info(f"Successfully saved to {path} without optimizer") return True except Exception as e: logger.warning(f"Third save attempt failed: {str(e)}") # Attempt 4: Save model structure (as JSON) and parameters separately logger.info(f"Saving model to {path} (attempt 4 - model structure as JSON)") try: # Save only essential model parameters as JSON model_params = { 'epsilon': float(model.epsilon), 'state_size': model.state_size, 'action_size': model.action_size, 'hidden_size': model.hidden_size, 'lstm_layers': model.policy_net.lstm_layers if hasattr(model.policy_net, 'lstm_layers') else 2, 'attention_heads': model.policy_net.attention_heads if hasattr(model.policy_net, 'attention_heads') else 4 } params_path = f"{path}.params.json" with open(params_path, 'w') as f: json.dump(model_params, f) logger.info(f"Successfully saved model parameters to {params_path}") # Now try to save a smaller version of the model without CNN components # This is a more minimal save for recovery purposes try: # Create stripped down checkpoint with minimal components minimal_checkpoint = { 'epsilon': model.epsilon, 'state_size': model.state_size, 'action_size': model.action_size, 'hidden_size': model.hidden_size } minimal_path = f"{path}.minimal" torch.save(minimal_checkpoint, minimal_path, _use_new_zipfile_serialization=False, pickle_protocol=2) logger.info(f"Successfully saved minimal checkpoint to {minimal_path}") except Exception as e: logger.warning(f"Minimal checkpoint save failed: {str(e)}") logger.info(f"Model saved successfully to {path}") return True except Exception as e: logger.error(f"All save attempts failed for {path}: {str(e)}") return False def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False): """ Delete old model files to free up disk space. Args: keep_best (bool): Whether to keep the best model files (reward, pnl, net_pnl) keep_latest_n (int): Number of latest checkpoint files to keep aggressive (bool): If True, apply more aggressive cleanup in very low disk scenarios """ try: logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}, aggressive={aggressive}") models_dir = "models" # Get all files in the models directory all_files = os.listdir(models_dir) # Files to potentially delete checkpoint_files = [] backup_files = [] params_files = [] dated_files = [] # Best files to keep if keep_best is True best_patterns = [ "trading_agent_best_reward.pt", "trading_agent_best_pnl.pt", "trading_agent_best_net_pnl.pt", "trading_agent_final.pt" ] # Categorize files for potential deletion for filename in all_files: file_path = os.path.join(models_dir, filename) # Skip directories if os.path.isdir(file_path): continue # Skip current best files if keep_best is True if keep_best and any(filename == pattern for pattern in best_patterns): continue # Check for different file types if "checkpoint" in filename and filename.endswith(".pt"): checkpoint_files.append((filename, os.path.getmtime(file_path), file_path)) elif filename.endswith(".backup"): backup_files.append((filename, os.path.getmtime(file_path), file_path)) elif filename.endswith(".params.json"): params_files.append((filename, os.path.getmtime(file_path), file_path)) elif "_2025" in filename or "_2024" in filename: # Files with date stamps dated_files.append((filename, os.path.getmtime(file_path), file_path)) bytes_freed = 0 files_deleted = 0 # Process checkpoint files - keep the newest N if len(checkpoint_files) > keep_latest_n: # Sort by modification time (newest first) checkpoint_files.sort(key=lambda x: x[1], reverse=True) # Keep the newest N files files_to_delete = checkpoint_files[keep_latest_n:] # Delete old checkpoint files for _, _, file_path in files_to_delete: try: file_size = os.path.getsize(file_path) os.remove(file_path) bytes_freed += file_size files_deleted += 1 logging.info(f"Deleted old checkpoint file: {file_path}") except Exception as e: logging.error(f"Failed to delete file {file_path}: {str(e)}") # If aggressive cleanup is enabled, remove more files if aggressive: # Delete all backup files except the newest one if backup_files: backup_files.sort(key=lambda x: x[1], reverse=True) for _, _, file_path in backup_files[1:]: # Keep only newest backup try: file_size = os.path.getsize(file_path) os.remove(file_path) bytes_freed += file_size files_deleted += 1 logging.info(f"Deleted old backup file: {file_path}") except Exception as e: logging.error(f"Failed to delete file {file_path}: {str(e)}") # Delete all dated files (these are typically archived models) for _, _, file_path in dated_files: try: file_size = os.path.getsize(file_path) os.remove(file_path) bytes_freed += file_size files_deleted += 1 logging.info(f"Deleted dated model file: {file_path}") except Exception as e: logging.error(f"Failed to delete file {file_path}: {str(e)}") logging.info(f"Cleanup complete. Deleted {files_deleted} files, freed {bytes_freed / (1024*1024):.2f} MB") # Check available disk space after cleanup try: if platform.system() == 'Windows': free_bytes = ctypes.c_ulonglong(0) ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(os.path.abspath(models_dir)), None, None, ctypes.pointer(free_bytes)) free_mb = free_bytes.value / (1024 * 1024) else: st = os.statvfs(os.path.abspath(models_dir)) free_mb = (st.f_bavail * st.f_frsize) / (1024 * 1024) logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB") # If space is still low, recommend aggressive cleanup if free_mb < 200 and not aggressive: # Less than 200MB available logging.warning("Disk space still critically low. Consider using aggressive cleanup.") except Exception as e: logging.error(f"Error checking disk space: {str(e)}") except Exception as e: logging.error(f"Error during file cleanup: {str(e)}") logging.error(traceback.format_exc()) def compact_save(model, optimizer, reward, epsilon, state_size, action_size, hidden_size, path, use_quantization=False): """ Save a model in a compact format suitable for low disk space environments. Includes fallbacks if the primary save method fails. Args: model: The model to save optimizer: The optimizer to save reward: The current reward epsilon: The current epsilon value state_size: The state size action_size: The action size hidden_size: The hidden size path: The path to save to use_quantization: Whether to use quantization to reduce model size Returns: bool: Whether the save was successful """ try: # Create minimal checkpoint with essential data only checkpoint = { 'model_state_dict': model.state_dict(), 'epsilon': epsilon, 'state_size': state_size, 'action_size': action_size, 'hidden_size': hidden_size } # Apply quantization if requested if use_quantization: try: logging.info(f"Attempting quantized save to {path}") # Quantize model to int8 quantized_model = torch.quantization.quantize_dynamic( model, # the original model {torch.nn.Linear}, # a set of layers to dynamically quantize dtype=torch.qint8 # the target dtype for quantized weights ) # Create quantized checkpoint quantized_checkpoint = { 'model_state_dict': quantized_model.state_dict(), 'epsilon': epsilon, 'state_size': state_size, 'action_size': action_size, 'hidden_size': hidden_size, 'is_quantized': True } # Save with older pickle protocol and disable new zipfile serialization torch.save(quantized_checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) logging.info(f"Quantized compact save successful to {path}") return True except Exception as e: logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}") # Fall back to regular save if quantization fails # Regular save with older pickle protocol and no zipfile serialization torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) logging.info(f"Compact save successful to {path}") return True except Exception as e: logging.error(f"Compact save failed: {str(e)}") logging.error(traceback.format_exc()) # Fallback: Save just the parameters as JSON if we can't save the full model try: params = { 'epsilon': epsilon, 'state_size': state_size, 'action_size': action_size, 'hidden_size': hidden_size } json_path = f"{path}.params.json" with open(json_path, 'w') as f: json.dump(params, f) logging.info(f"Saved minimal parameters to {json_path}") return False except Exception as json_e: logging.error(f"JSON parameter save failed: {str(json_e)}") return False if __name__ == "__main__": # Parse command line arguments parser = argparse.ArgumentParser(description='Trading Bot') parser.add_argument('--mode', type=str, default='train', help='Mode: train, test, live') parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train') parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode') parser.add_argument('--update_interval', type=int, default=10, help='Target network update interval') parser.add_argument('--training_iterations', type=int, default=10, help='Number of training iterations per step') parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol') parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for candlestick data') parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage') parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes') parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training') parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space') parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up') args = parser.parse_args() # Import platform and ctypes for disk space checking import platform import ctypes # Run cleanup if requested if args.cleanup: cleanup_model_files(keep_best=True, keep_latest_n=args.keep_latest, aggressive=args.aggressive_cleanup) try: asyncio.run(main()) except KeyboardInterrupt: logger.info("Program terminated by user")