# gogo2/main.py
import os
import time
import json
import numpy as np
import pandas as pd
import random
import logging
import asyncio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from dotenv import load_dotenv
import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp  # mixed-precision (AMP) utilities
from sklearn.preprocessing import MinMaxScaler
import copy
import argparse
import traceback
import io
import matplotlib.dates as mdates
from matplotlib.figure import Figure
from PIL import Image
import matplotlib.pyplot as mpf  # NOTE: re-imports pyplot under a second alias; the `mpf` name suggests mplfinance was intended
import matplotlib.gridspec as gridspec
import datetime  # module import (used as datetime.datetime below); the class is aliased as `dt`
from datetime import datetime as dt
from collections import defaultdict
from gym.spaces import Discrete, Box
import csv
import gc
import shutil
import math
import platform
import ctypes
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)
logger = logging.getLogger("trading_bot")
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
# Constants
INITIAL_BALANCE = 100 # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5%
MEMORY_SIZE = 100000
BATCH_SIZE = 64
GAMMA = 0.99 # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 64 # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10 # Update target network every 10 episodes
# Experience replay tuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
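# Example of a stored transition (illustrative values, not from the original file):
#   Experience(state=np.zeros(STATE_SIZE), action=1, reward=0.1,
#              next_state=np.zeros(STATE_SIZE), done=False)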
# Utility: detect local tops/bottoms in a price series (used by pattern detection below)
def find_local_extrema(prices, window=5):
    """
    Find local extrema (tops and bottoms) in a price series.

    Args:
        prices: Array of price values
        window: Window size for finding extrema

    Returns:
        Tuple of (tops, bottoms) index lists
    """
    tops = []
    bottoms = []
    if len(prices) < window * 2 + 1:
        return tops, bottoms
    prices = np.asarray(prices)  # ensure array so negation/slicing below work
    try:
        # Use peak detection from scipy if available
        from scipy.signal import find_peaks
        # Find peaks (tops)
        peaks, _ = find_peaks(prices, distance=window)
        tops = list(peaks)
        # Find valleys (bottoms) by inverting the prices
        valleys, _ = find_peaks(-prices, distance=window)
        bottoms = list(valleys)
        # Optional: filter extrema for significance
        if len(tops) > 0 and len(bottoms) > 0:
            # Average absolute one-step price move
            avg_move = np.mean(np.abs(np.diff(prices)))
            # Keep only tops significantly higher than their surroundings
            filtered_tops = []
            for top in tops:
                if window < top < len(prices) - window:
                    surrounding_min = min(prices[top - window:top + window])
                    if prices[top] - surrounding_min > avg_move * 1.5:  # 1.5x average move
                        filtered_tops.append(top)
            # Keep only bottoms significantly lower than their surroundings
            filtered_bottoms = []
            for bottom in bottoms:
                if window < bottom < len(prices) - window:
                    surrounding_max = max(prices[bottom - window:bottom + window])
                    if surrounding_max - prices[bottom] > avg_move * 1.5:  # 1.5x average move
                        filtered_bottoms.append(bottom)
            tops = filtered_tops
            bottoms = filtered_bottoms
    except ImportError:
        # Fallback to manual detection if scipy is not available
        for i in range(window, len(prices) - window):
            # Local maximum
            if all(prices[i] >= prices[i - j] for j in range(1, window + 1)) and \
               all(prices[i] >= prices[i + j] for j in range(1, window + 1)):
                tops.append(i)
            # Local minimum
            if all(prices[i] <= prices[i - j] for j in range(1, window + 1)) and \
               all(prices[i] <= prices[i + j] for j in range(1, window + 1)):
                bottoms.append(i)
    return tops, bottoms
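
# Usage sketch (illustrative values, not from the original file):
#   closes = np.array([100, 101, 103, 102, 99, 98, 100, 104, 103, 101, 100, 102], dtype=float)
#   tops, bottoms = find_local_extrema(closes, window=2)
#   # tops/bottoms are integer indices into `closes`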

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
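
# Sketch of how a sampled batch would be tensorized for a DQN update
# (shapes and names here are assumptions, not from the original file):
#   memory = ReplayMemory(MEMORY_SIZE)
#   batch = memory.sample(BATCH_SIZE)
#   states = torch.FloatTensor(np.array([e.state for e in batch]))
#   actions = torch.LongTensor([e.action for e in batch]).unsqueeze(1)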

class DQN(nn.Module):
    """Deep Q-Network with enhanced architecture"""

    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super(DQN, self).__init__()
        self.network = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads)
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.attention_heads = attention_heads

    def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None):
        # Delegate to LSTMAttentionDQN; pass the multi-timeframe inputs when available
        if x_1m is not None and x_1h is not None and x_1d is not None:
            return self.network(state, x_1s, x_1m, x_1h, x_1d)
        else:
            return self.network(state)

class PricePredictionModel(nn.Module):
    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
        super(PricePredictionModel, self).__init__()
        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False

    def forward(self, x):
        # x shape: [batch_size, seq_len, 1]
        lstm_out, _ = self.lstm(x)
        # Use the last time step output
        predictions = self.fc(lstm_out[:, -1, :])
        return predictions

    def preprocess(self, data):
        # Reshape data for the scaler
        data_reshaped = np.array(data).reshape(-1, 1)
        # Fit the scaler on first use only
        if not self.is_fitted:
            self.scaler.fit(data_reshaped)
            self.is_fitted = True
        return self.scaler.transform(data_reshaped)

    def postprocess(self, scaled_predictions):
        # Inverse transform to get actual price values
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        if len(price_history) < 30:  # Need enough history
            return np.zeros(num_candles)
        scaled_data = self.preprocess(price_history)
        # Build a single sequence from the last 30 time steps
        sequence = scaled_data[-30:].reshape(1, 30, 1)
        sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)
        # Get predictions
        with torch.no_grad():
            scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
        return self.postprocess(scaled_predictions)

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        if len(price_history) < 35:  # Need enough history for training
            return 0.0
        scaled_data = self.preprocess(price_history)
        # Create sequences (30 steps) and targets (next 5 steps)
        sequences = []
        targets = []
        for i in range(len(scaled_data) - 35):
            sequences.append(scaled_data[i:i + 30])
            targets.append(scaled_data[i + 30:i + 35].flatten())
        if not sequences:  # If no sequences were created
            return 0.0
        device = next(self.parameters()).device
        sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(device)
        targets_tensor = torch.FloatTensor(np.array(targets)).to(device)
        # Training loop
        total_loss = 0
        for _ in range(epochs):
            predictions = self(sequences_tensor)
            loss = F.mse_loss(predictions, targets_tensor)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / epochs
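
# Usage sketch (hypothetical `prices` series; optimizer choice is an assumption):
#   model = PricePredictionModel().to(get_device())
#   opt = optim.Adam(model.parameters(), lr=1e-3)
#   avg_loss = model.train_on_new_data(prices, opt, epochs=10)
#   next_five = model.predict_next_candles(prices, num_candles=5)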

class TradingEnvironment:
    def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
                 window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
                 symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
        """Initialize the trading environment"""
        self.api_key = api_key
        self.api_secret = api_secret
        self.exchange_id = exchange_id
        self.symbol = symbol
        self.timeframe = timeframe
        self.init_length = init_length
        # TODO: For 1s/tick timeframes, implement WebSocket API integration for real-time data
        try:
            # Initialize exchange if API credentials are provided
            if api_key and api_secret:
                self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
                logger.info(f"Exchange initialized: {exchange_id}")
                # Fetch historical data
                self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
                if not self.data:
                    raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Loaded {len(self.data)} candles from exchange")
            elif data is not None:  # Use provided data
                self.data = data
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Using provided data with {len(self.data)} candles")
            else:
                # Initialize with empty data; load later with fetch_initial_data
                logger.warning("No data provided, initializing with empty data")
                self.data = []
                self.data_format_is_list = True
        except Exception as e:
            logger.error(f"Error initializing environment: {e}")
            raise
        # Initialize features and feature extractors
        if features is not None:
            self.features = features
            # Create a dictionary of features
            self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
        else:
            # Initialize features as a dictionary keyed by indicator name
            self.features = {
                'price': [],
                'volume': [],
                'rsi': [],
                'macd': [],
                'macd_signal': [],
                'macd_hist': [],
                'bollinger_upper': [],
                'bollinger_mid': [],
                'bollinger_lower': [],
                'stoch_k': [],
                'stoch_d': [],
                'ema_9': [],
                'ema_21': [],
                'atr': []
            }
            self.features_dict = {}
        if feature_extractors is None:
            feature_extractors = []
        self.feature_extractors = feature_extractors
        # Environment parameters
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.leverage = leverage
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.commission = commission
        self.total_pnl = 0
        self.total_fees = 0.0  # Track total fees paid
        self.trades = []
        self.trade_signals = []
        self.current_step = 0
        self.window_size = window_size
        self.max_steps = max_steps
        self.peak_balance = initial_balance
        self.max_drawdown = 0
        self.current_price = 0
        self.win_count = 0
        self.loss_count = 0
        self.min_position_size = 100  # Minimum position size in USD
        # Track candle patterns and reversal points
        self.patterns = {}
        self.reversal_points = []
        # Define observation and action spaces
        num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
        state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features
        self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)
        # Check if we have enough data
        if len(self.data) < self.window_size:
            logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")

    def calculate_reward(self, action):
        """Calculate reward based on the action taken"""
        reward = 0
        # Base reward structure
        if self.position == 'flat':
            if action == 0:  # HOLD when flat
                reward = 0.01  # Small reward for holding when no position
            elif action == 1:  # BUY/LONG
                # Check for buy signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                    buy_confidence = self.cnn_patterns['long_confidence']
                    # Scale by confidence
                    reward = 0.1 * buy_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Small penalty, capped at 0.05
                    reward -= fee_penalty
            elif action == 2:  # SELL/SHORT
                # Check for sell signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                    sell_confidence = self.cnn_patterns['short_confidence']
                    # Scale by confidence
                    reward = 0.1 * sell_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)
                    reward -= fee_penalty
            elif action == 3:  # CLOSE when no position
                reward = -0.1  # Penalty for trying to close a non-existent position
        elif self.position == 'long':
            if action == 0:  # HOLD long position
                # Price change since entry
                price_change = (self.current_price - self.entry_price) / self.entry_price
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding a profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding a losing position
            elif action == 1:  # BUY when already long
                reward = -0.1  # Penalty for redundant action
            elif action == 2:  # SELL when long (reversal)
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing a profitable long to go short
                else:
                    # Check for sell signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                        sell_confidence = self.cnn_patterns['short_confidence']
                        reward = 0.2 * sell_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting a loss
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)
                    reward -= fee_penalty
            elif action == 3:  # CLOSE long position
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting a loss
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)
                    reward -= fee_penalty
        elif self.position == 'short':
            if action == 0:  # HOLD short position
                # Price change since entry (positive when price falls)
                price_change = (self.entry_price - self.current_price) / self.entry_price
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding a profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding a losing position
            elif action == 1:  # BUY when short (reversal)
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing a profitable short to go long
                else:
                    # Check for buy signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                        buy_confidence = self.cnn_patterns['long_confidence']
                        reward = 0.2 * buy_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting a loss
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)
                    reward -= fee_penalty
            elif action == 2:  # SELL when already short
                reward = -0.1  # Penalty for redundant action
            elif action == 3:  # CLOSE short position
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting a loss
                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)
                    reward -= fee_penalty
        return reward

    def reset(self):
        """Reset the environment to its initial state and return the initial observation"""
        self.balance = self.initial_balance
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.current_step = 0
        self.trades = []
        self.trade_signals = []
        self.total_pnl = 0.0
        self.total_fees = 0.0
        self.peak_balance = self.initial_balance
        self.max_drawdown = 0.0
        self.win_count = 0
        self.loss_count = 0
        return self.get_state()

    def add_data(self, candle):
        """Add a new candle to the data"""
        # Candle may be a list ([timestamp, open, high, low, close, volume]) or a dictionary
        if isinstance(candle, list):
            self.data_format_is_list = True
            self.data.append(candle)
            self.current_price = candle[4]  # Close price is at index 4
        else:
            self.data_format_is_list = False
            self.data.append(candle)
            self.current_price = candle['close']
        self._update_features()

    def _initialize_features(self):
        """Initialize technical indicators and features"""
        if len(self.data) < 30:
            return
        # Convert data to a pandas DataFrame for easier calculation
        if self.data_format_is_list:
            df = pd.DataFrame(self.data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        else:
            df = pd.DataFrame(self.data)
        # Basic price and volume
        self.features['price'] = df['close'].values
        self.features['volume'] = df['volume'].values
        # RSI (14 periods)
        delta = df['close'].diff()
        gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
        rs = gain / loss
        self.features['rsi'] = (100 - (100 / (1 + rs))).fillna(50).values
        # MACD
        ema12 = df['close'].ewm(span=12, adjust=False).mean()
        ema26 = df['close'].ewm(span=26, adjust=False).mean()
        macd = ema12 - ema26
        signal = macd.ewm(span=9, adjust=False).mean()
        self.features['macd'] = macd.values
        self.features['macd_signal'] = signal.values
        self.features['macd_hist'] = (macd - signal).values
        # Bollinger Bands
        sma20 = df['close'].rolling(window=20).mean()
        std20 = df['close'].rolling(window=20).std()
        self.features['bollinger_upper'] = (sma20 + 2 * std20).values
        self.features['bollinger_mid'] = sma20.values
        self.features['bollinger_lower'] = (sma20 - 2 * std20).values
        # Stochastic Oscillator
        low_14 = df['low'].rolling(window=14).min()
        high_14 = df['high'].rolling(window=14).max()
        k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
        self.features['stoch_k'] = k.values
        self.features['stoch_d'] = k.rolling(window=3).mean().values
        # EMAs
        self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
        self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values
        # ATR (true range averaged over 14 periods)
        high_low = df['high'] - df['low']
        high_close = (df['high'] - df['close'].shift()).abs()
        low_close = (df['low'] - df['close'].shift()).abs()
        tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
        self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values

    def _update_features(self):
        """Update technical indicators with new data"""
        self._initialize_features()  # Recalculate all features

    async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
        """Fetch initial historical data for the environment"""
        try:
            logger.info(f"Fetching initial data for {symbol}")
            # Use the refactored fetch method
            data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
            # Update environment with fetched data
            if data:
                self.data = data
                self._initialize_features()
                logger.info(f"Initialized environment with {len(data)} candles")
            else:
                logger.warning("No initial data received")
            return bool(data)  # guards against None as well as an empty list
        except Exception as e:
            logger.error(f"Error fetching initial data: {e}")
            return False
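
    # Async usage sketch (illustrative; assumes an async ccxt exchange instance and
    # that fetch_ohlcv_data is defined elsewhere in this module):
    #   env = TradingEnvironment()
    #   ok = asyncio.run(env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000))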

    def step(self, action):
        """Take an action in the environment and return the next state, reward, and done flag"""
        # Check if we have enough data
        if self.current_step >= len(self.data) - 1:
            # We've reached the end of data
            done = True
            next_state = self.get_state()
            info = {
                'action': 'none',
                'price': self.current_price,
                'balance': self.balance,
                'position': self.position,
                'pnl': self.total_pnl
            }
            return next_state, 0, done, info
        # Circuit breaker after consecutive losses
        if self.count_consecutive_losses() >= 5:
            logger.warning("Circuit breaker triggered after 5 consecutive losses")
            return self.get_state(), -1, True, {'action': 'circuit_breaker_triggered'}
        # Reduce leverage in volatile markets
        if self.is_volatile_market():
            self.leverage = MAX_LEVERAGE * 0.5  # Half leverage in volatile markets
        else:
            self.leverage = MAX_LEVERAGE
        # Store current price before taking action
        if self.data_format_is_list:
            self.current_price = self.data[self.current_step][4]  # Close price
        else:
            self.current_price = self.data[self.current_step]['close']
        # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
        reward = self.calculate_reward(action)
        # Execute the action
        initial_balance = self.balance  # Store initial balance to calculate PnL
        position_before = self.position  # Remember the position before it is mutated below
        # Open long position
        if action == 1 and self.position != 'long':
            if self.position == 'short':
                # Close short position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage
                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar
                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee
                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']
                    self.trades.append({
                        'type': 'short',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })
                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1
            # Now open long position
            self.position = 'long'
            self.entry_price = self.current_price
            self.entry_index = self.current_step
            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()
            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee
            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit
            self.stop_loss = self.entry_price * (1 - sl_percent)
            self.take_profit = self.entry_price * (1 + tp_percent)
        # Open short position
        elif action == 2 and self.position != 'short':
            if self.position == 'long':
                # Close long position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage
                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar
                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee
                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']
                    self.trades.append({
                        'type': 'long',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })
                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1
            # Now open short position
            self.position = 'short'
            self.entry_price = self.current_price
            self.entry_index = self.current_step
            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()
            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee
            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit
            self.stop_loss = self.entry_price * (1 + sl_percent)
            self.take_profit = self.entry_price * (1 - tp_percent)
        # Close position
        elif action == 3 and self.position != 'flat':
            if self.position == 'long':
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage
                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee
                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })
                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
            elif self.position == 'short':
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage
                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee
                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })
                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
            # Reset position
            self.position = 'flat'
            self.position_size = 0
            self.entry_price = 0
            self.entry_index = 0
            self.stop_loss = 0
            self.take_profit = 0
        # Record trade signal for visualization
        if action > 0:  # If not HOLD
            signal_type = None
            if action == 1:  # BUY/LONG
                signal_type = 'buy'
            elif action == 2:  # SELL/SHORT
                signal_type = 'sell'
            elif action == 3:  # CLOSE
                # Use the position held before it was reset above (the original
                # checked self.position here, which is always 'flat' by this point)
                if position_before == 'long':
                    signal_type = 'close_long'
                elif position_before == 'short':
                    signal_type = 'close_short'
            if signal_type:
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']
                self.trade_signals.append({
                    'timestamp': timestamp,
                    'price': self.current_price,
                    'type': signal_type,
                    'balance': self.balance,
                    'pnl': self.total_pnl
                })
        # Check for stop loss / take profit hits
        self.check_sl_tp()
        # Move to next step
        self.current_step += 1
        done = self.current_step >= len(self.data) - 1
        # Get new state
        next_state = self.get_state()
        # Update peak balance and drawdown
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance
        current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
        self.max_drawdown = max(self.max_drawdown, current_drawdown)
        # Create info dictionary
        info = {
            'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
            'price': self.current_price,
            'balance': self.balance,
            'position': self.position,
            'pnl': self.total_pnl,
            'fees': self.total_fees,
            'net_pnl': self.total_pnl - self.total_fees
        }
        return next_state, reward, done, info
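
    # Environment loop sketch (illustrative; the agent's policy call is an assumption):
    #   state = env.reset()
    #   done = False
    #   while not done:
    #       action = agent.select_action(state)  # hypothetical policy method
    #       state, reward, done, info = env.step(action)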

    def check_sl_tp(self):
        """Check if stop loss or take profit has been hit, with an improved trailing stop"""
        if self.position == 'flat':
            return
        # Timestamp of the current candle, handling both data formats (the original
        # assumed dictionary format here, which breaks on list-format data)
        if self.data_format_is_list:
            timestamp = self.data[self.current_step][0]
        else:
            timestamp = self.data[self.current_step]['timestamp']
        if self.position == 'long':
            # Trail the stop loss once in profit
            if self.current_price > self.entry_price * 1.01:
                self.stop_loss = max(self.stop_loss, self.current_price * 0.995)  # Trail at 0.5% below current price
            # Check stop loss
            if self.current_price <= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size
                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)
                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Record trade
                trade_duration = self.current_step - self.entry_index
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'stop_loss'
                })
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
                logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.position_size = 0
                self.entry_price = 0
                self.entry_index = 0
                self.stop_loss = 0
                self.take_profit = 0
            # Check take profit
            elif self.current_price >= self.take_profit:
                # Take profit hit
                pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size
                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)
                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Record trade
                trade_duration = self.current_step - self.entry_index
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'take_profit'
                })
                self.win_count += 1
                logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.position_size = 0
                self.entry_price = 0
                self.entry_index = 0
                self.stop_loss = 0
                self.take_profit = 0
        elif self.position == 'short':
            # Trail the stop loss once in profit
            if self.current_price < self.entry_price * 0.99:
                self.stop_loss = min(self.stop_loss, self.current_price * 1.005)  # Trail at 0.5% above current price
            # Check stop loss
            if self.current_price >= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size
                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)
                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Record trade
                trade_duration = self.current_step - self.entry_index
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'stop_loss'
                })
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
                logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.position_size = 0
                self.entry_price = 0
                self.entry_index = 0
                self.stop_loss = 0
                self.take_profit = 0
            # Check take profit
            elif self.current_price <= self.take_profit:
                # Take profit hit
                pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size
                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)
                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                # Record trade
                trade_duration = self.current_step - self.entry_index
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'take_profit'
                })
                self.win_count += 1
                logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.position_size = 0
                self.entry_price = 0
                self.entry_index = 0
                self.stop_loss = 0
                self.take_profit = 0

    def get_state(self):
        """Create state representation for the agent with enhanced features"""
        # Ensure we have enough data
        if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
            # Return zeros if not enough data
            return np.zeros(STATE_SIZE)
        # Create a normalized state vector with recent price action and indicators
        state_components = []
        # Safely get the latest price
        try:
            latest_price = self.features['price'][-1]
        except IndexError:
            # If we can't get the latest price, return zeros
            return np.zeros(STATE_SIZE)
        # Safely get price features
        try:
            # Price features (normalize recent prices by the latest price)
            price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
            state_components.append(price_features)
        except (IndexError, ZeroDivisionError):
            state_components.append(np.zeros(10))
        # Safely get volume features
        try:
            # Volume features (normalize by max volume)
            max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
            vol_features = np.array(self.features['volume'][-5:]) / max_vol
            state_components.append(vol_features)
        except (IndexError, ZeroDivisionError):
            state_components.append(np.zeros(5))
        # Technical indicators
        rsi = np.array(self.features['rsi'][-3:]) / 100.0  # Scale to 0-1
        state_components.append(rsi)
        # MACD (normalize)
        macd_vals = np.array(self.features['macd'][-3:])
        macd_signal = np.array(self.features['macd_signal'][-3:])
        macd_hist = np.array(self.features['macd_hist'][-3:])
        macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
        macd_norm = macd_vals / macd_scale
        macd_signal_norm = macd_signal / macd_scale
        macd_hist_norm = macd_hist / macd_scale
        state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm])
        # Bollinger position (where price sits relative to the bands)
        bb_upper = np.array(self.features['bollinger_upper'][-3:])
        bb_lower = np.array(self.features['bollinger_lower'][-3:])
        bb_mid = np.array(self.features['bollinger_mid'][-3:])
        price = np.array(self.features['price'][-3:])
        # Position of price within the Bollinger Bands (0 to 1)
        bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
        state_components.append(np.array(bb_pos))
        # Stochastic oscillator
        state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0)
        state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0)
        # Add predicted prices (if available)
        if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
            # Normalize predictions relative to current price
            pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
            state_components.append(pred_norm)
        else:
            # Add zeros if no predictions
            state_components.append(np.zeros(3))
        # Add extrema signals (if available)
        if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0:
            # Get the most recent signals
            idx = max(0, len(self.optimal_signals) - 5)
            recent_signals = self.optimal_signals[idx:idx + 5]
            # Pad if needed
            if len(recent_signals) < 5:
                recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
            state_components.append(recent_signals)
        else:
            # Add zeros if no signals
            state_components.append(np.zeros(5))
        # Position info
        position_info = np.zeros(5)
        if self.position == 'long':
            position_info[0] = 1.0  # Long position flag
            position_info[1] = (latest_price - self.entry_price) / self.entry_price  # Unrealized PnL %
            position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price  # Stop loss %
            position_info[3] = (self.take_profit - self.entry_price) / self.entry_price  # Take profit %
            position_info[4] = self.position_size / self.balance  # Position size relative to balance
        elif self.position == 'short':
            position_info[0] = -1.0  # Short position flag
            position_info[1] = (self.entry_price - latest_price) / self.entry_price  # Unrealized PnL %
            position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price  # Stop loss %
            position_info[3] = (self.entry_price - self.take_profit) / self.entry_price  # Take profit %
            position_info[4] = self.position_size / self.balance  # Position size relative to balance
        state_components.append(position_info)
        # --- Additional engineered features ---
        # 1. Price momentum features (rate of change)
        if len(self.features['price']) >= 20:
            roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0
            roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0
            roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0
            state_components.append(np.array([roc_5, roc_10, roc_20]))
        else:
            state_components.append(np.zeros(3))
        # 2. Volatility features
        if len(self.features['price']) >= 20:
            # Price returns and their standard deviation
            returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1]
            volatility = np.std(returns)
            # Normalized high-low range, handling both data formats
            if self.data_format_is_list:
                # List format: high at index 2, low at index 3, close at index 4
                high_low_range = np.mean([
                    (self.data[i][2] - self.data[i][3]) / self.data[i][4]
                    for i in range(max(0, self.current_step - 5), min(len(self.data), self.current_step + 1))
                ]) if len(self.data) > 0 else 0
            else:
                # Dictionary format
                high_low_range = np.mean([
                    (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close']
                    for i in range(max(0, self.current_step - 5), min(len(self.data), self.current_step + 1))
                ]) if len(self.data) > 0 else 0
            # ATR normalized by price
            atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0
            state_components.append(np.array([volatility, high_low_range, atr_norm]))
        else:
            state_components.append(np.zeros(3))
        # 3. Market regime features
        if len(self.features['price']) >= 50:
            # Trend strength (ADX-like measure)
            ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
            ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
            trend_strength = abs(ema9 - ema21) / ema21
            # Range-bound vs trending
            is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
            is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0
            # Proximity to support/resistance
            near_support = 1.0 if self.is_near_support() else 0.0
            near_resistance = 1.0 if self.is_near_resistance() else 0.0
            state_components.append(np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance]))
        else:
            state_components.append(np.zeros(5))
        # 4. Trade history features
        if len(self.trades) > 0:
            # Recent win/loss ratio
            recent_trades = self.trades[-min(10, len(self.trades)):]
            win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades)
            # Average profit/loss
            avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0
            avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0
            # Normalize by balance
            avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0
            avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0
            # Last trade result
            last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0
            state_components.append(np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]))
        else:
            state_components.append(np.zeros(4))
        # Combine all features (np.asarray guards against plain-list components)
        state = np.concatenate([np.asarray(comp).flatten() for comp in state_components])
        # Replace any NaN or infinite values
        state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        # Ensure the state has the correct size
        if len(state) != STATE_SIZE:
            logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
            # Pad or truncate to match the expected size
            if len(state) < STATE_SIZE:
                state = np.pad(state, (0, STATE_SIZE - len(state)))
            else:
                state = state[:STATE_SIZE]
        return state

    def get_expanded_state_size(self):
        """Calculate the size of the expanded state representation"""
        # Create a dummy state to get its size
        state = self.get_state()
        return len(state)

    # NOTE: this is written as a free function taking (agent, env) rather than
    # (self, ...), but it sits at method scope in the source; it is kept here
    # with its original signature and should be called via the class, e.g.
    # TradingEnvironment.expand_model_with_new_features(agent, env).
    async def expand_model_with_new_features(agent, env):
        """Expand the model to handle new features without retraining from scratch"""
        # Get the new state size
        new_state_size = env.get_expanded_state_size()
        # Only expand if the new state size is larger
        if new_state_size > agent.state_size:
            logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")
            # Expand the model
            success = agent.expand_model(
                new_state_size=new_state_size,
                new_hidden_size=512,    # Increase hidden size for more capacity
                new_lstm_layers=3,      # More layers for deeper patterns
                new_attention_heads=8   # More attention heads for complex relationships
            )
            if success:
                logger.info(f"Model successfully expanded to handle {new_state_size} features")
                return True
            else:
                logger.error("Failed to expand model")
                return False
        else:
            logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
            return True

    # NOTE: this second definition of calculate_reward overrides the earlier one
    # above (in Python, the later def of the same name wins within a class body).
    def calculate_reward(self, action):
        """
        Calculate reward for taking the given action.

        Args:
            action: The action taken (0=hold, 1=buy, 2=sell, 3=close)

        Returns:
            The calculated reward
        """
        reward = 0
        # Get current price
        if self.data_format_is_list:
            current_price = self.data[self.current_step][4]  # Close price
        else:
            current_price = self.data[self.current_step]['close']
        # Base reward component based on price movement
        price_change_pct = 0
        if self.current_step > 0:
            if self.data_format_is_list:
                prev_price = self.data[self.current_step - 1][4]  # Previous close
            else:
                prev_price = self.data[self.current_step - 1]['close']
            price_change_pct = (current_price - prev_price) / prev_price
        # Check if we have CNN patterns available
        pattern_confidence = 0
        if hasattr(self, 'cnn_patterns'):
            if action == 1 and 'long_confidence' in self.cnn_patterns:  # Buy action
                pattern_confidence = self.cnn_patterns['long_confidence']
            elif action == 2 and 'short_confidence' in self.cnn_patterns:  # Sell action
                pattern_confidence = self.cnn_patterns['short_confidence']
        # Action-specific rewards
        if action == 0:  # HOLD
            # Small positive reward for holding in the direction of market movement
            if self.position == 'long' and price_change_pct > 0:
                reward += 0.1 + price_change_pct * 10
            elif self.position == 'short' and price_change_pct < 0:
                reward += 0.1 + abs(price_change_pct) * 10
            else:
                # Small negative reward for holding in the wrong direction
                reward -= 0.1
        elif action == 1 or action == 2:  # BUY or SELL
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            position_size = self.calculate_position_size()
            fee = (position_size / 1900) * 1  # Trading fee in USD
            fee_penalty = fee / 10  # Scale down to a reasonable penalty
            reward -= fee_penalty
            # Fee accounting
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee
        elif action == 3:  # CLOSE
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            fee = (self.position_size / 1900) * 1  # Trading fee in USD
            fee_penalty = fee / 10  # Scale down to a reasonable penalty
            reward -= fee_penalty
            # Fee accounting
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee
        # Add CNN pattern confidence to reward
        reward += pattern_confidence * 10
        return reward

    async def initialize_futures(self, exchange):
        """Initialize futures trading parameters"""
        if not self.demo:
            try:
                # Set up futures trading parameters
                await exchange.set_position_mode(True)  # Hedge mode
                await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
                await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
                logger.info(f"Futures initialized with {self.leverage}x leverage")
            except Exception as e:
                logger.error(f"Failed to initialize futures trading: {str(e)}")
                logger.info("Falling back to demo mode for safety")
                self.demo = True  # was `demo = True`, which only set a local variable

    async def execute_real_trade(self, exchange, action, current_price):
        """Execute a real futures trade on MEXC"""
        try:
            order = None  # ensure `order` is bound even for HOLD/unknown actions
            position_size = self.calculate_position_size()
            if action == 1:  # Open long
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='buy',
                    amount=position_size,
                    params={'positionSide': 'LONG'}
                )
                logger.info(f"Opened LONG position: {order}")
            elif action == 2:  # Open short
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='sell',
                    amount=position_size,
                    params={'positionSide': 'SHORT'}
                )
                logger.info(f"Opened SHORT position: {order}")
            elif action == 3:  # Close position
                position_side = 'LONG' if self.position == 'long' else 'SHORT'
                order = await exchange.create_order(
                    symbol=self.futures_symbol,
                    type='market',
                    side='sell' if position_side == 'LONG' else 'buy',
                    amount=self.position_size,
                    params={'positionSide': position_side}
                )
                logger.info(f"Closed {position_side} position: {order}")
            return order
        except Exception as e:
            logger.error(f"Trade execution failed: {e}")
            return None

    def trades_in_last_n_candles(self, n=20):
        """Count the number of trades in the last n candles"""
        if len(self.trades) == 0:
            return 0
        if self.data_format_is_list:
            # List format: timestamp at index 0
            current_time = self.data[self.current_step][0]
            n_candles_ago = self.data[max(0, self.current_step - n)][0]
        else:
            # Dictionary format
            current_time = self.data[self.current_step]['timestamp']
            n_candles_ago = self.data[max(0, self.current_step - n)]['timestamp']
        count = 0
        for trade in reversed(self.trades):
            if 'timestamp' in trade and n_candles_ago <= trade['timestamp'] <= current_time:
                count += 1
            else:
                # Trades are appended in time order, so older ones can be skipped
                break
        return count

    def count_consecutive_losses(self):
        """Count the number of consecutive losing trades, most recent first"""
        count = 0
        for trade in reversed(self.trades):
            if trade.get('pnl_dollar', 0) < 0:
                count += 1
            else:
                break
        return count

    def is_volatile_market(self):
        """Determine if the current market is volatile"""
        if len(self.features['price']) < 20:
            return False
        recent_prices = self.features['price'][-20:]
        avg_price = sum(recent_prices) / len(recent_prices)
        volatility = sum(abs(p - avg_price) / avg_price for p in recent_prices) / len(recent_prices)
        return volatility > 0.01  # 1% average deviation counts as volatile

    def is_uptrend(self):
        """Determine if the market is in an uptrend"""
        if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2:
            return False
        # Short-term trend
        short_trend = self.features['ema_9'][-1] > self.features['ema_9'][-2]
        # Medium-term trend
        medium_trend = self.features['ema_9'][-1] > self.features['ema_21'][-1]
        return short_trend and medium_trend

    def is_downtrend(self):
        """Determine if the market is in a downtrend"""
        if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2:
            return False
        # Short-term trend
        short_trend = self.features['ema_9'][-1] < self.features['ema_9'][-2]
        # Medium-term trend
        medium_trend = self.features['ema_9'][-1] < self.features['ema_21'][-1]
        return short_trend and medium_trend

    def calculate_position_size(self):
        """Calculate position size based on risk management rules"""
        # Reduce position size after losses
        consecutive_losses = self.count_consecutive_losses()
        risk_factor = max(0.3, 1.0 - (consecutive_losses * 0.1))  # Reduce by 10% per loss, min 30%
        # Size the position from available balance and per-trade risk
        max_risk_amount = self.balance * 0.02  # Risk 2% per trade
        position_size = max_risk_amount / (STOP_LOSS_PERCENT / 100 * self.current_price)
        # Apply leverage
        position_size = position_size * self.leverage
        # Cap at available balance (NOTE: the quotient above is in base-currency
        # units while this cap is in USD; kept as in the original)
        position_size = min(position_size, self.balance * self.leverage)
        return position_size * risk_factor
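
    # Worked example (illustrative numbers): with balance=100, current_price=2000,
    # leverage=50, STOP_LOSS_PERCENT=0.5 and no losing streak:
    #   max_risk_amount = 100 * 0.02 = 2.0
    #   raw size        = 2.0 / (0.005 * 2000) = 0.2
    #   leveraged       = 0.2 * 50 = 10.0, capped at min(10.0, 100 * 50) = 10.0
    #   risk_factor     = 1.0, so the returned size is 10.0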
    # ... existing identify_optimal_trades method ...

    def update_cnn_patterns(self, candle_data=None):
        """
        Update CNN patterns using multi-timeframe data.

        Args:
            candle_data: Dictionary containing candle data for different timeframes
        """
        if not candle_data:
            return
        try:
            # Check that the necessary timeframes are present
            required_timeframes = ['1m', '1h', '1d']
            if not all(tf in candle_data for tf in required_timeframes):
                logging.warning("Missing required timeframes for CNN pattern detection")
                return
            # Initialize patterns if not already done
            if not hasattr(self, 'cnn_patterns'):
                self.cnn_patterns = {}
            # Extract features from candle data
            features = {}
            # Process each timeframe
            for tf in required_timeframes:
                candles = candle_data[tf]
                if not candles or len(candles) < 30:
                    continue
                # Convert to numpy arrays for easier processing
                closes = np.array([c[4] for c in candles[-100:]])
                highs = np.array([c[2] for c in candles[-100:]])
                lows = np.array([c[3] for c in candles[-100:]])
                # Simple feature extraction
                # 1. Detect trends
                ema20 = self._calculate_ema(closes, 20)
                ema50 = self._calculate_ema(closes, 50)
                uptrend = ema20[-1] > ema50[-1] and closes[-1] > ema20[-1]
                downtrend = ema20[-1] < ema50[-1] and closes[-1] < ema20[-1]
                # 2. Detect potential reversal patterns via local extrema
                tops, bottoms = find_local_extrema(closes, window=10)
                # Check if we're near a bottom (potential buy)
                near_bottom = False
                bottom_confidence = 0
                if bottoms:
                    last_bottom = bottoms[-1]
                    if len(closes) - last_bottom < 5:  # Recent bottom
                        bottom_dist = abs(closes[-1] - closes[last_bottom]) / closes[last_bottom]
                        if bottom_dist < 0.01:  # Within 1% of the bottom
                            near_bottom = True
                            # Confidence decays with distance from the bottom
                            bottom_confidence = 0.8 - bottom_dist * 50  # 0.8 to 0.3 range
                # Check if we're near a top (potential sell)
                near_top = False
                top_confidence = 0
                if tops:
                    last_top = tops[-1]
                    if len(closes) - last_top < 5:  # Recent top
                        top_dist = abs(closes[-1] - closes[last_top]) / closes[last_top]
                        if top_dist < 0.01:  # Within 1% of the top
                            near_top = True
                            # Confidence decays with distance from the top
                            top_confidence = 0.8 - top_dist * 50  # 0.8 to 0.3 range
                # Store features for this timeframe
                features[tf] = {
                    'uptrend': uptrend,
                    'downtrend': downtrend,
                    'near_bottom': near_bottom,
                    'bottom_confidence': bottom_confidence,
                    'near_top': near_top,
                    'top_confidence': top_confidence
                }
            # Combine features across timeframes into overall pattern confidence
            long_confidence = 0
            short_confidence = 0
            # Weight each timeframe (higher weight for longer timeframes)
            weights = {'1m': 0.2, '1h': 0.3, '1d': 0.5}
            for tf, tf_features in features.items():
                weight = weights.get(tf, 0.2)
                # Add to long confidence
                if tf_features['uptrend'] or tf_features['near_bottom']:
                    long_confidence += weight * (0.6 if tf_features['uptrend'] else 0) + \
                                       weight * (tf_features['bottom_confidence'] if tf_features['near_bottom'] else 0)
                # Add to short confidence
                if tf_features['downtrend'] or tf_features['near_top']:
                    short_confidence += weight * (0.6 if tf_features['downtrend'] else 0) + \
                                        weight * (tf_features['top_confidence'] if tf_features['near_top'] else 0)
            # Clamp confidence scores to [0, 1]
            long_confidence = min(1.0, long_confidence)
            short_confidence = min(1.0, short_confidence)
            # Update patterns
            self.cnn_patterns = {
                'long_confidence': long_confidence,
                'short_confidence': short_confidence,
                'features': features
            }
            logging.debug(f"Updated CNN patterns - Long: {long_confidence:.2f}, Short: {short_confidence:.2f}")
        except Exception as e:
            logging.error(f"Error updating CNN patterns: {e}")

    def _calculate_ema(self, data, span):
        """Calculate exponential moving average"""
        alpha = 2 / (span + 1)
        alpha_rev = 1 - alpha
        ema = np.zeros_like(data)
        ema[0] = data[0]
        for i in range(1, len(data)):
            ema[i] = alpha * data[i] + alpha_rev * ema[i - 1]
        return ema
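
    # For float input this recursion matches pandas' adjust=False EMA, e.g.
    #   pd.Series(closes).ewm(span=20, adjust=False).mean()
    # produces the same values as self._calculate_ema(closes, 20).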

    def is_uncertain_market(self):
        """Determine if the market is in an uncertain/range-bound state"""
        if len(self.features['price']) < 30:
            return False
        # Check if EMAs are close to each other (no clear trend)
        if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
            ema9 = self.features['ema_9'][-1]
            ema21 = self.features['ema_21'][-1]
            # If EMAs are within 0.2% of each other, the market is uncertain
            if abs(ema9 - ema21) / ema21 < 0.002:
                return True
        # Check if price is oscillating without clear direction
        if len(self.features['price']) >= 10:
            recent_prices = self.features['price'][-10:]
            ups = downs = 0
            for i in range(1, len(recent_prices)):
                if recent_prices[i] > recent_prices[i - 1]:
                    ups += 1
                else:
                    downs += 1
            # Uncertain if neither direction dominates heavily
            return abs(ups - downs) < 3
        return False

    def is_near_support(self):
        """Determine if the current price is near a support level"""
        if len(self.features['price']) < 30:
            return False
        # Use the Bollinger lower band as support
        if len(self.features['bollinger_lower']) > 0 and len(self.features['price']) > 0:
            current_price = self.features['price'][-1]
            lower_band = self.features['bollinger_lower'][-1]
            # If price is within 0.5% of the lower band
            if (current_price - lower_band) / current_price < 0.005:
                return True
        # Check if we're near recent lows
        if len(self.features['price']) >= 20:
            current_price = self.features['price'][-1]
            min_price = min(self.features['price'][-20:])
            # If within 1% of recent lows
            if (current_price - min_price) / current_price < 0.01:
                return True
        return False

    def is_near_resistance(self):
        """Determine if the current price is near a resistance level"""
        if len(self.features['price']) < 30:
            return False
        # Use the Bollinger upper band as resistance
        if len(self.features['bollinger_upper']) > 0 and len(self.features['price']) > 0:
            current_price = self.features['price'][-1]
            upper_band = self.features['bollinger_upper'][-1]
            # If price is within 0.5% of the upper band
            if (upper_band - current_price) / current_price < 0.005:
                return True
        # Check if we're near recent highs
        if len(self.features['price']) >= 20:
            current_price = self.features['price'][-1]
            max_price = max(self.features['price'][-20:])
            # If within 1% of recent highs
            if (max_price - current_price) / current_price < 0.01:
                return True
        return False

    def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
        """
        Add a candlestick chart and metrics to TensorBoard.

        Parameters:
        - writer: TensorBoard writer
        - step: Current step
        - title: Title for the chart
        """
        try:
            # Initialize writer if not provided
            if writer is None:
                from torch.utils.tensorboard import SummaryWriter
                writer = SummaryWriter()
            # Log basic metrics
            writer.add_scalar('Balance', self.balance, step)
            writer.add_scalar('Total_PnL', self.total_pnl, step)
            # Log total fees if available
            if hasattr(self, 'total_fees'):
                writer.add_scalar('Total_Fees', self.total_fees, step)
                writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)
            # Log position info
            writer.add_scalar('Position_Size', self.position_size, step)
            # Log drawdown and win rate
            writer.add_scalar('Max_Drawdown', self.max_drawdown, step)
            win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0
            writer.add_scalar('Win_Rate', win_rate, step)
            # Log trade count
            writer.add_scalar('Trade_Count', len(self.trades), step)
            # Check if we have enough data for a candlestick chart
            if len(self.data) <= 0:
                logger.warning("No data available for candlestick chart")
                return
            # Chart the last 100 data points
            start_idx = max(0, self.current_step - 100)
            end_idx = self.current_step
            # Show the most recent trades (last 10) on the chart
            recent_trades = self.trades[-10:] if self.trades else []
            try:
                fig = create_candlestick_figure(
                    self.data[start_idx:end_idx + 1],
                    title=title,
                    trades=recent_trades
                )
                # Add figure to TensorBoard
                writer.add_figure('Candlestick_Chart', fig, step)
                # Close figure to free memory
                plt.close(fig)
            except Exception as e:
                logger.error(f"Error creating candlestick chart: {e}")
        except Exception as e:
            logger.error(f"Error adding chart to TensorBoard: {e}")
            # Continue execution even if charting fails
def get_realtime_state(self, tick_data):
"""
Create a state representation optimized for real-time processing.
This is a streamlined version of get_state() designed for minimal latency.
TODO: Implement optimized state creation from tick data
"""
# This would be a simplified version of get_state that processes only
# the most important features needed for real-time decision making
# Example implementation:
# realtime_features = {
# 'price': tick_data['price'],
# 'volume': tick_data['volume'],
# 'ema_short': self._calculate_ema(tick_data['price'], 9),
# 'ema_long': self._calculate_ema(tick_data['price'], 21),
# }
# Convert to tensor or numpy array in the required format
# return torch.tensor([...], dtype=torch.float32)
# Placeholder
return np.zeros((self.observation_space.shape[0],), dtype=np.float32)
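    # Hedged sketch (an assumption, not wired in anywhere): a streaming EMA
    # helper in the spirit of the _calculate_ema the commented example above
    # assumes. State is kept per-period so repeated ticks update incrementally.
    def _calculate_ema_sketch(self, price, period):
        """Incrementally update and return an EMA for the given period (illustrative only)."""
        if not hasattr(self, '_ema_state'):
            self._ema_state = {}
        alpha = 2.0 / (period + 1)
        prev = self._ema_state.get(period, price)
        ema = alpha * price + (1 - alpha) * prev
        self._ema_state[period] = ema
        return ema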
# Ensure GPU usage if available
def get_device():
"""Get the best available device (CUDA GPU or CPU)"""
if torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Set up for mixed precision training
torch.backends.cudnn.benchmark = True
else:
device = torch.device("cpu")
logger.info("GPU not available, using CPU")
return device
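# Note: cudnn.benchmark (enabled above) trades deterministic kernel selection
# for speed; it helps most when input shapes stay fixed between batches, which
# is the case for this model's fixed-size state and candle tensors.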
# Agent: DQN agent with an LSTM+attention policy network and GPU support
class Agent:
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
"""Initialize Agent with architecture parameters stored as attributes"""
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size # Store hidden_size as an instance attribute
self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute
self.attention_heads = attention_heads # Store attention_heads as an instance attribute
# Set device
self.device = device if device is not None else get_device()
# Initialize networks - use LSTMAttentionDQN instead of DQN
self.policy_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
# Initialize optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
# Initialize replay memory
self.memory = ReplayMemory(MEMORY_SIZE)
# Initialize exploration parameters
self.epsilon = EPSILON_START
self.epsilon_start = EPSILON_START
self.epsilon_end = EPSILON_END
self.epsilon_decay = EPSILON_DECAY
self.epsilon_min = EPSILON_END
# Initialize step counter
self.steps_done = 0
# Initialize TensorBoard writer
self.writer = None
# Initialize GradScaler for mixed precision training
self.scaler = torch.amp.GradScaler('cuda') if self.device.type == "cuda" else None
# Initialize candle cache for multi-timeframe data
self.candle_cache = CandleCache()
# Store model name for logging
self.model_name = f"LSTM_Attention_DQN_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info(f"Initialized agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
logger.info(f"Using device: {self.device}")
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
"""Expand the model to handle more features or increase capacity"""
logger.info(f"Expanding model: {self.state_size}{new_state_size}, "
f"hidden: {self.policy_net.hidden_size}{new_hidden_size}")
# Save old weights
old_state_dict = self.policy_net.state_dict()
# Create new larger networks
new_policy_net = LSTMAttentionDQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
new_target_net = LSTMAttentionDQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
# Transfer weights for common layers
new_state_dict = new_policy_net.state_dict()
for name, param in old_state_dict.items():
if name in new_state_dict:
# If shapes match, copy directly
if new_state_dict[name].shape == param.shape:
new_state_dict[name] = param
# For first layer, copy weights for the original input dimensions
elif name == "fc1.weight":
new_state_dict[name][:, :self.state_size] = param
# For other layers, initialize with a strategy that preserves scale
else:
logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}")
# Load transferred weights
new_policy_net.load_state_dict(new_state_dict)
new_target_net.load_state_dict(new_state_dict)
# Replace networks
self.policy_net = new_policy_net
self.target_net = new_target_net
self.target_net.eval()
# Update optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        # Update stored architecture parameters
        self.state_size = new_state_size
        self.hidden_size = new_hidden_size
        self.lstm_layers = new_lstm_layers
        self.attention_heads = new_attention_heads
# Print new model size
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"New model size: {total_params:,} parameters")
return True
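    # Note: expand_model migrates the fc1 input weights explicitly, while other
    # mismatched layers keep their fresh initialization, so an expanded model
    # typically needs some fine-tuning before it matches the original's behavior.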
def select_action(self, state, training=True, candle_data=None):
"""
Select an action using the policy network.
Args:
state: The current state
training: Whether we're in training mode (for epsilon-greedy)
            candle_data: Dictionary with '1m', '1h', and '1d' candle data ('1s' support is planned for later)
Returns:
The selected action
"""
        # Build CNN inputs if multi-timeframe candle data is available
        cnn_inputs = None
        if candle_data and all(k in candle_data for k in ['1m', '1h', '1d']):
# Process candle data into tensors
# x_1s = self.prepare_candle_tensor(candle_data['1s'])
x_1m = self.prepare_candle_tensor(candle_data['1m'])
x_1h = self.prepare_candle_tensor(candle_data['1h'])
x_1d = self.prepare_candle_tensor(candle_data['1d'])
cnn_inputs = (x_1m, x_1h, x_1d)
# Use epsilon-greedy strategy during training
if training and random.random() < self.epsilon:
return random.randrange(self.action_size)
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
if cnn_inputs:
q_values = self.policy_net(state_tensor, *cnn_inputs)
else:
q_values = self.policy_net(state_tensor)
return q_values.max(1)[1].item()
def prepare_candle_tensor(self, candles, max_candles=300):
"""Convert candle data to tensors for CNN input"""
if not candles:
# Return zeros if no candles available
return torch.zeros((1, 5, max_candles), device=self.device)
# Limit to the most recent candles
candles = candles[-max_candles:]
# Extract OHLCV data
ohlcv = np.array([[c[1], c[2], c[3], c[4], c[5]] for c in candles], dtype=np.float32)
# Normalize the data
if len(ohlcv) > 0:
# Simple min-max normalization per column
min_vals = ohlcv.min(axis=0, keepdims=True)
max_vals = ohlcv.max(axis=0, keepdims=True)
range_vals = max_vals - min_vals
range_vals[range_vals == 0] = 1 # Avoid division by zero
ohlcv = (ohlcv - min_vals) / range_vals
# Pad if needed
padded = np.zeros((max_candles, 5), dtype=np.float32)
padded[-len(ohlcv):] = ohlcv
# Convert to tensor [batch, channels, sequence]
tensor = torch.FloatTensor(padded.transpose(1, 0)).unsqueeze(0).to(self.device)
return tensor
else:
return torch.zeros((1, 5, max_candles), device=self.device)
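    # Shape note (illustrative): 300 OHLCV candles become a tensor of shape
    # [1, 5, 300]; shorter histories are left-padded with zeros so the most
    # recent candle always sits at the end of the sequence axis.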
def learn(self):
"""Learn from a batch of experiences with GPU acceleration and CNN features"""
if len(self.memory) < BATCH_SIZE:
return None
try:
# Sample a batch of experiences
experiences = self.memory.sample(BATCH_SIZE)
            # Convert experiences to tensors (batch through numpy to avoid the
            # slow list-of-arrays construction path)
            states = torch.as_tensor(np.array([e.state for e in experiences]), dtype=torch.float32, device=self.device)
            actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long, device=self.device)
            rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32, device=self.device)
            next_states = torch.as_tensor(np.array([e.next_state for e in experiences]), dtype=torch.float32, device=self.device)
            dones = torch.as_tensor([float(e.done) for e in experiences], dtype=torch.float32, device=self.device)
# Use mixed precision for forward/backward passes if on GPU
if self.device.type == "cuda" and self.scaler is not None:
                with torch.amp.autocast('cuda'):
# Compute current Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
# Compute target Q values
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass with gradient scaling
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
# Clip gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
# Update weights
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# Standard precision for CPU
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values with target network
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
# Reshape target values to match current_q_values
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
# Update steps done
self.steps_done += 1
# Update target network
if self.steps_done % TARGET_UPDATE == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
return loss.item()
except Exception as e:
logger.error(f"Error during learning: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
return None
def update_epsilon(self, episode):
"""Update epsilon value based on episode number"""
# Calculate epsilon using a linear decay formula
epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
max(0, (self.epsilon_decay - episode)) / self.epsilon_decay
# Update self.epsilon with the calculated value
self.epsilon = max(self.epsilon_min, epsilon)
return self.epsilon
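    # Illustrative values for the linear schedule above, assuming the module
    # defaults (EPSILON_START=1.0, EPSILON_END=0.05, EPSILON_DECAY=10000):
    # episode 0 -> 1.0, episode 5000 -> 0.525, episode >= 10000 -> 0.05.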
def update_target_network(self):
self.target_net.load_state_dict(self.policy_net.state_dict())
def save(self, path="models/trading_agent_best_pnl.pt"):
"""Save the model using a robust saving approach with multiple fallbacks"""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(path), exist_ok=True)
# Call robust save function
success = robust_save(self, path)
if success:
logger.info(f"Model saved successfully to {path}")
return True
else:
logger.error(f"All save attempts failed for path: {path}")
return False
except Exception as e:
logger.error(f"Error in save method: {e}")
logger.error(traceback.format_exc())
return False
def load(self, path="models/trading_agent_best_pnl.pt"):
"""Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
try:
# First try to load with weights_only=False (for models saved with older PyTorch versions)
try:
logger.info(f"Attempting to load model with weights_only=False: {path}")
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
logger.info("Model loaded successfully with weights_only=False")
except Exception as e1:
logger.warning(f"Failed to load with weights_only=False: {e1}")
# Try with safe_globals context manager
try:
logger.info("Attempting to load with safe_globals context manager")
import numpy as np
from torch.serialization import safe_globals
# Add numpy scalar to safe globals
with safe_globals(['numpy._core.multiarray.scalar']):
checkpoint = torch.load(path, map_location=self.device)
logger.info("Model loaded successfully with safe_globals")
except Exception as e2:
logger.warning(f"Failed to load with safe_globals: {e2}")
# Last resort: try with pickle_module=pickle
logger.info("Attempting to load with pickle_module")
import pickle
checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
logger.info("Model loaded successfully with pickle_module")
# Load state dictionaries
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
# Try to load optimizer state
try:
self.optimizer.load_state_dict(checkpoint['optimizer'])
except Exception as e:
logger.warning(f"Could not load optimizer state: {e}")
# Load epsilon if available
if 'epsilon' in checkpoint:
self.epsilon = checkpoint['epsilon']
# Load architecture parameters if available
if 'state_size' in checkpoint:
self.state_size = checkpoint['state_size']
if 'action_size' in checkpoint:
self.action_size = checkpoint['action_size']
if 'hidden_size' in checkpoint:
self.hidden_size = checkpoint['hidden_size']
else:
# If hidden_size not in checkpoint, infer from model
try:
self.hidden_size = self.policy_net.fc1.weight.shape[0]
logger.info(f"Inferred hidden_size={self.hidden_size} from model")
                except Exception:
                    self.hidden_size = 256  # Default value
                    logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
if 'lstm_layers' in checkpoint:
self.lstm_layers = checkpoint['lstm_layers']
else:
self.lstm_layers = 2 # Default value
if 'attention_heads' in checkpoint:
self.attention_heads = checkpoint['attention_heads']
else:
self.attention_heads = 4 # Default value
logger.info(f"Model loaded successfully from {path}")
except Exception as e:
logger.error(f"Error loading model: {e}")
import traceback
logger.error(traceback.format_exc())
raise
def add_chart_to_tensorboard(self, env, step):
"""Add candlestick chart to tensorboard and various metrics"""
try:
# Initialize writer if it doesn't exist
if not hasattr(self, 'writer') or self.writer is None:
self.writer = SummaryWriter(log_dir=f'runs/{self.model_name}')
# Check if we have enough data
if not hasattr(env, 'data') or len(env.data) < 20:
logger.warning("Not enough data for chart in TensorBoard")
return
# Get position value (convert from string if needed)
position_value = 0 # Default to flat
if hasattr(env, 'position'):
if isinstance(env.position, str):
# Map string positions to numeric values
position_map = {'flat': 0, 'long': 1, 'short': -1}
position_value = position_map.get(env.position.lower(), 0)
else:
position_value = float(env.position)
# Log metrics to tensorboard
self.writer.add_scalar('Trading/Position', position_value, step)
if hasattr(env, 'balance'):
self.writer.add_scalar('Trading/Balance', env.balance, step)
if hasattr(env, 'total_pnl'):
self.writer.add_scalar('Trading/Total_PnL', env.total_pnl, step)
if hasattr(env, 'max_drawdown'):
self.writer.add_scalar('Trading/Drawdown', env.max_drawdown, step)
if hasattr(env, 'win_rate'):
self.writer.add_scalar('Trading/Win_Rate', env.win_rate, step)
if hasattr(env, 'trade_count'):
self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)
# Log trading fees
if hasattr(env, 'total_fees'):
self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
# Also log net PnL (after fees)
if hasattr(env, 'total_pnl'):
self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)
# Add candlestick chart if we have enough data
if len(env.data) >= 100:
try:
# Use the last 100 candles for the chart
recent_data = env.data[-100:]
# Get recent trades if available
recent_trades = None
if hasattr(env, 'trades') and len(env.trades) > 0:
recent_trades = env.trades[-10:] # Last 10 trades
# Create candlestick figure
fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")
if fig:
# Add to tensorboard
self.writer.add_figure('Trading/Chart', fig, step)
# Close figure to free memory
plt.close(fig)
except Exception as e:
logger.warning(f"Error creating candlestick chart: {e}")
except Exception as e:
logger.error(f"Error in add_chart_to_tensorboard: {e}")
def select_action_realtime(self, state):
"""
Select action with minimal latency for real-time trading.
Optimized version of select_action for ultra-low latency requirements.
TODO: Implement optimized action selection for real-time trading
"""
# Convert to tensor if needed
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device)
        # Fast forward pass through the network; fall back to the standard
        # forward() if the network does not define a dedicated realtime path
        with torch.no_grad():
            if hasattr(self.policy_net, 'forward_realtime'):
                q_values = self.policy_net.forward_realtime(state_tensor.unsqueeze(0))
            else:
                q_values = self.policy_net(state_tensor.unsqueeze(0))
        # Get the action with highest Q-value
        action = q_values.max(1)[1].item()
return action
def forward_realtime(self, state):
"""
Optimized forward pass for real-time trading with minimal latency.
TODO: Implement streamlined forward pass that prioritizes speed
"""
        # For now, just delegate to the policy network's standard forward pass
        # (the Agent itself defines no forward(), so self.forward would fail).
        # This could be optimized later with techniques like:
        # - Using a smaller model for real-time decisions
        # - Skipping certain layers or calculations
        # - Using quantized weights or other optimizations
        return self.policy_net(state)
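# Minimal, self-contained sketch of the mixed-precision update pattern used in
# Agent.learn() above. The toy network and random batch are assumptions for
# illustration only; nothing here is wired into the trading pipeline.
def _amp_update_sketch():
    """Run one autocast/GradScaler optimizer step on a toy regression loss."""
    if not torch.cuda.is_available():
        return  # the autocast/GradScaler pattern below assumes a CUDA device
    device = torch.device("cuda")
    net = nn.Linear(8, 4).to(device)
    opt = optim.Adam(net.parameters(), lr=1e-4)
    scaler = torch.amp.GradScaler('cuda')
    states = torch.randn(32, 8, device=device)
    targets = torch.randn(32, 4, device=device)
    with torch.amp.autocast('cuda'):
        loss = F.smooth_l1_loss(net(states), targets)
    opt.zero_grad()
    scaler.scale(loss).backward()
    # Unscale before clipping so the max-norm applies to true gradient values
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    scaler.step(opt)
    scaler.update()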
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
"""Get live price data using websockets"""
# Connect to MEXC websocket
uri = "wss://stream.mexc.com/ws"
async with websockets.connect(uri) as websocket:
# Subscribe to kline data
subscribe_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
}
await websocket.send(json.dumps(subscribe_msg))
logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")
while True:
try:
response = await websocket.recv()
data = json.loads(response)
if 'data' in data:
kline = data['data']
candle = {
'timestamp': kline['t'],
'open': float(kline['o']),
'high': float(kline['h']),
'low': float(kline['l']),
'close': float(kline['c']),
'volume': float(kline['v'])
}
yield candle
except Exception as e:
logger.error(f"Websocket error: {e}")
                # Back off briefly, then end the stream; the caller is
                # expected to re-enter this generator to reconnect
                await asyncio.sleep(5)
                break
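# Illustrative consumer for the async generator above (an assumption, not used
# elsewhere in this module); run it inside an event loop, e.g. asyncio.run():
async def _print_live_prices_sketch(symbol="ETH/USDT", timeframe="1m"):
    async for candle in get_live_prices(symbol, timeframe):
        logger.info(f"Live {symbol} candle close: {candle['close']}")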
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, use_compact_save=False):
"""
Train the agent in the environment.
Args:
agent: The agent to train
env: The trading environment
num_episodes: Number of episodes to train for
max_steps_per_episode: Maximum steps per episode
use_compact_save: Whether to use compact save (for low disk space)
Returns:
Training statistics
"""
# Initialize TensorBoard writer if not already done
try:
if agent.writer is None:
from torch.utils.tensorboard import SummaryWriter
agent.writer = SummaryWriter(log_dir=f'runs/{agent.model_name}')
writer = agent.writer
except Exception as e:
logging.error(f"Failed to initialize TensorBoard: {e}")
writer = None
# Initialize exchange for data fetching
try:
exchange = await initialize_exchange()
logging.info("Initialized exchange for data fetching")
except Exception as e:
logging.error(f"Failed to initialize exchange: {e}")
exchange = None
# Initialize statistics tracking
stats = {
'episode_rewards': [],
'episode_lengths': [],
'balances': [],
'win_rates': [],
'episode_pnls': [],
'cumulative_pnl': [],
'drawdowns': [],
'trade_counts': [],
'loss_values': [],
'fees': [], # Track fees
'net_pnl_after_fees': [] # Track net PnL after fees
}
# Track best models
best_reward = float('-inf')
best_pnl = float('-inf')
best_net_pnl = float('-inf') # Track best net PnL (after fees)
# Make directory for models if it doesn't exist
os.makedirs('models', exist_ok=True)
# Memory management function
def clean_memory():
"""Clean up memory to avoid memory leaks"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Start training loop
for episode in range(num_episodes):
try:
# Clean up memory before starting a new episode
clean_memory()
# Reset environment
state = env.reset()
episode_reward = 0
episode_losses = []
# Fetch multi-timeframe data at the start of the episode
candle_data = None
if exchange:
try:
candle_data = await fetch_multi_timeframe_data(
exchange, "ETH/USDT", agent.candle_cache
)
# Update CNN patterns
env.update_cnn_patterns(candle_data)
logging.info(f"Fetched multi-timeframe data for episode {episode+1}")
except Exception as e:
logging.error(f"Failed to fetch candle data: {e}")
# Track consecutive errors
consecutive_errors = 0
max_consecutive_errors = 5
# Episode loop
for step in range(max_steps_per_episode):
try:
# Select action using CNN-enhanced policy
action = agent.select_action(state, training=True, candle_data=candle_data)
# Take action
next_state, reward, done, info = env.step(action)
# Store transition in replay memory
agent.memory.push(state, action, reward, next_state, done)
# Move to the next state
state = next_state
# Update episode reward
episode_reward += reward
# Learn from experience
if len(agent.memory) > BATCH_SIZE:
try:
loss = agent.learn()
if loss is not None:
episode_losses.append(loss)
# Log loss to TensorBoard
global_step = episode * max_steps_per_episode + step
if writer:
writer.add_scalar('Loss/step', loss, global_step)
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logging.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update target network periodically
if step % TARGET_UPDATE == 0:
agent.update_target_network()
# Update price predictions and CNN patterns periodically
if step % 50 == 0:
try:
# Update internal environment predictions
if hasattr(env, 'update_price_predictions'):
env.update_price_predictions()
if hasattr(env, 'identify_optimal_trades'):
env.identify_optimal_trades()
# Fetch fresh candle data periodically
if exchange:
try:
candle_data = await fetch_multi_timeframe_data(
exchange, "ETH/USDT", agent.candle_cache
)
# Update CNN patterns with the new candle data
env.update_cnn_patterns(candle_data)
logging.info(f"Updated multi-timeframe data at step {step}")
except Exception as e:
logging.error(f"Failed to fetch candle data: {e}")
except Exception as e:
logging.warning(f"Error updating predictions: {e}")
# Clean memory periodically during long episodes
if step % 200 == 0 and step > 0:
clean_memory()
# Add chart to TensorBoard periodically
if step % 100 == 0 or (step == max_steps_per_episode - 1) or done:
try:
global_step = episode * max_steps_per_episode + step
if writer:
agent.add_chart_to_tensorboard(env, global_step)
except Exception as e:
logging.warning(f"Error adding chart to TensorBoard: {e}")
if done:
break
except Exception as e:
logging.error(f"Error in training step: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Calculate statistics from this episode
balance = env.balance
pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
fees = env.total_fees if hasattr(env, 'total_fees') else 0
net_pnl = pnl - fees # Calculate net PnL after fees
# Get trading statistics
trade_analysis = None
if hasattr(env, 'analyze_trades'):
trade_analysis = env.analyze_trades()
win_rate = trade_analysis['win_rate'] if trade_analysis and 'win_rate' in trade_analysis else 0
trade_count = trade_analysis['total_trades'] if trade_analysis and 'total_trades' in trade_analysis else 0
max_drawdown = trade_analysis['max_drawdown'] if trade_analysis and 'max_drawdown' in trade_analysis else 0
# Calculate average loss for this episode
avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0
# Log episode metrics to TensorBoard
if writer:
writer.add_scalar('Reward/episode', episode_reward, episode)
writer.add_scalar('Balance/episode', balance, episode)
writer.add_scalar('PnL/episode', pnl, episode)
writer.add_scalar('NetPnL/episode', net_pnl, episode)
writer.add_scalar('Fees/episode', fees, episode)
writer.add_scalar('WinRate/episode', win_rate, episode)
writer.add_scalar('TradeCount/episode', trade_count, episode)
writer.add_scalar('Drawdown/episode', max_drawdown, episode)
writer.add_scalar('Loss/episode', avg_loss, episode)
writer.add_scalar('Epsilon/episode', agent.epsilon, episode)
# Update stats dictionary
stats['episode_rewards'].append(episode_reward)
stats['episode_lengths'].append(step + 1)
stats['balances'].append(balance)
stats['win_rates'].append(win_rate)
stats['episode_pnls'].append(pnl)
stats['drawdowns'].append(max_drawdown)
stats['trade_counts'].append(trade_count)
stats['loss_values'].append(avg_loss)
stats['fees'].append(fees)
stats['net_pnl_after_fees'].append(net_pnl)
            # Update cumulative PnL (the 'cumulative_pnl' list is pre-initialized in stats,
            # and episode_pnls is guaranteed non-empty after the append above)
            cumulative_pnl = sum(stats['episode_pnls'])
            stats['cumulative_pnl'].append(cumulative_pnl)
            if writer:
                writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
                writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)
# Save model if this is the best reward or PnL
if episode_reward > best_reward:
best_reward = episode_reward
try:
if use_compact_save:
success = compact_save(agent, 'models/trading_agent_best_reward.pt')
else:
success = agent.save('models/trading_agent_best_reward.pt')
if success:
logging.info(f"New best reward: {best_reward:.2f}")
except Exception as e:
logging.error(f"Error saving best reward model: {e}")
if pnl > best_pnl:
best_pnl = pnl
try:
if use_compact_save:
success = compact_save(agent, 'models/trading_agent_best_pnl.pt')
else:
success = agent.save('models/trading_agent_best_pnl.pt')
if success:
logging.info(f"New best PnL: ${best_pnl:.2f}")
except Exception as e:
logging.error(f"Error saving best PnL model: {e}")
# Save model if this is the best net PnL (after fees)
if net_pnl > best_net_pnl:
best_net_pnl = net_pnl
try:
if use_compact_save:
success = compact_save(agent, 'models/trading_agent_best_net_pnl.pt')
else:
success = agent.save('models/trading_agent_best_net_pnl.pt')
if success:
logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
except Exception as e:
logging.error(f"Error saving best net PnL model: {e}")
# Save checkpoint periodically
if episode % 10 == 0:
try:
if use_compact_save:
compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt')
else:
agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
except Exception as e:
logging.error(f"Error saving checkpoint model: {e}")
# Update epsilon
agent.update_epsilon(episode)
# Log training progress
logging.info(f"Episode {episode+1}/{num_episodes} | " +
f"Reward: {episode_reward:.2f} | " +
f"Balance: ${balance:.2f} | " +
f"PnL: ${pnl:.2f} | " +
f"Fees: ${fees:.2f} | " +
f"Net PnL: ${net_pnl:.2f} | " +
f"Win Rate: {win_rate:.2f} | " +
f"Trades: {trade_count} | " +
f"Loss: {avg_loss:.5f} | " +
f"Epsilon: {agent.epsilon:.4f}")
except Exception as e:
logging.error(f"Error in episode {episode}: {e}")
logging.error(traceback.format_exc())
continue
# Clean memory before saving final model
clean_memory()
# Save final model
try:
if use_compact_save:
compact_save(agent, 'models/trading_agent_final.pt')
else:
agent.save('models/trading_agent_final.pt')
except Exception as e:
logging.error(f"Error saving final model: {e}")
# Save training statistics to file
try:
import pandas as pd
# Make sure all arrays in stats are the same length by padding with NaN
max_length = max(len(v) for k, v in stats.items() if isinstance(v, list))
for k, v in stats.items():
if isinstance(v, list) and len(v) < max_length:
stats[k] = v + [float('nan')] * (max_length - len(v))
# Create dataframe and save
stats_df = pd.DataFrame(stats)
stats_df.to_csv('training_stats.csv', index=False)
logging.info(f"Training statistics saved to training_stats.csv")
except Exception as e:
logging.error(f"Failed to save training statistics: {e}")
logging.error(traceback.format_exc())
# Close exchange if it's still open
if exchange:
try:
# Check if exchange has the close method (ccxt.async_support)
if hasattr(exchange, 'close'):
await exchange.close()
logging.info("Closed exchange connection")
else:
logging.info("Exchange doesn't have close method (standard ccxt), skipping close")
except Exception as e:
logging.error(f"Error closing exchange: {e}")
return stats
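# Illustrative end-to-end call (assumes an initialized agent and env; the
# episode counts are placeholders):
#   stats = asyncio.run(train_agent(agent, env, num_episodes=10, max_steps_per_episode=200))
#   plot_training_results(stats)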
def plot_training_results(stats):
"""Plot training results and save to file"""
try:
# Check if we have data to plot
if not stats or len(stats.get('episode_rewards', [])) == 0:
logger.warning("No training data to plot")
return
# Create a DataFrame with consistent lengths
max_len = max(len(stats.get(key, [])) for key in stats)
# Ensure all arrays have the same length by padding with the last value or zeros
processed_stats = {}
for key, values in stats.items():
if not values: # Skip empty lists
continue
# Pad arrays to the same length
if len(values) < max_len:
if len(values) > 0:
# Pad with the last value
values = values + [values[-1]] * (max_len - len(values))
else:
# Pad with zeros
values = [0] * max_len
processed_stats[key] = values[:max_len] # Trim if longer
# Create DataFrame
df = pd.DataFrame(processed_stats)
# Add episode column
df['episode'] = range(1, len(df) + 1)
# Create figure with subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
# Plot episode rewards
if 'episode_rewards' in df.columns:
axes[0, 0].plot(df['episode'], df['episode_rewards'])
axes[0, 0].set_title('Episode Rewards')
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Reward')
axes[0, 0].grid(True)
# Plot account balance
if 'balances' in df.columns:
axes[0, 1].plot(df['episode'], df['balances'])
axes[0, 1].set_title('Account Balance')
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Balance ($)')
axes[0, 1].grid(True)
# Plot win rate
if 'win_rates' in df.columns:
axes[1, 0].plot(df['episode'], df['win_rates'])
axes[1, 0].set_title('Win Rate')
axes[1, 0].set_xlabel('Episode')
axes[1, 0].set_ylabel('Win Rate')
axes[1, 0].set_ylim([0, 1])
axes[1, 0].grid(True)
# Plot episode PnL
if 'episode_pnls' in df.columns:
axes[1, 1].plot(df['episode'], df['episode_pnls'])
axes[1, 1].set_title('Episode PnL')
axes[1, 1].set_xlabel('Episode')
axes[1, 1].set_ylabel('PnL ($)')
axes[1, 1].grid(True)
# Plot cumulative PnL
if 'cumulative_pnl' in df.columns:
axes[2, 0].plot(df['episode'], df['cumulative_pnl'])
axes[2, 0].set_title('Cumulative PnL')
axes[2, 0].set_xlabel('Episode')
axes[2, 0].set_ylabel('Cumulative PnL ($)')
axes[2, 0].grid(True)
# Plot maximum drawdown
if 'drawdowns' in df.columns:
axes[2, 1].plot(df['episode'], df['drawdowns'])
axes[2, 1].set_title('Maximum Drawdown')
axes[2, 1].set_xlabel('Episode')
axes[2, 1].set_ylabel('Drawdown')
axes[2, 1].grid(True)
# Adjust layout
plt.tight_layout()
# Save figure
plt.savefig('training_results.png')
logger.info("Training results saved to training_results.png")
# Save statistics to CSV
df.to_csv('training_stats.csv', index=False)
logger.info("Training statistics saved to training_stats.csv")
except Exception as e:
logger.error(f"Error plotting training results: {e}")
logger.error(traceback.format_exc())
def evaluate_agent(agent, env, num_episodes=10):
"""Evaluate the agent on test data"""
total_reward = 0
total_profit = 0
total_trades = 0
winning_trades = 0
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
initial_balance = env.balance
done = False
while not done:
# Select action (no exploration)
action = agent.select_action(state, training=False)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward += reward
total_reward += episode_reward
total_profit += env.balance - initial_balance
# Count trades and wins
for trade in env.trades:
if 'pnl_percent' in trade:
total_trades += 1
if trade['pnl_percent'] > 0:
winning_trades += 1
# Calculate averages
avg_reward = total_reward / num_episodes
avg_profit = total_profit / num_episodes
win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0
logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
f"Win Rate={win_rate:.1f}%")
return avg_reward, avg_profit, win_rate
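# Illustrative call (assumes a trained agent and a test-configured environment):
#   avg_reward, avg_profit, win_rate = evaluate_agent(agent, test_env, num_episodes=5)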
async def test_training():
"""Test the training process with a small number of episodes"""
logger.info("Starting training tests...")
# Initialize exchange
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True,
})
try:
# Create environment with small initial balance for testing
env = TradingEnvironment(
exchange=exchange,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=100, # Small balance for testing
demo=True # Always use demo mode for testing
)
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
        # Create agent (unwrap gym Discrete action spaces to an integer count)
        action_size = env.action_space.n if hasattr(env.action_space, 'n') else env.action_space
        agent = Agent(state_size=STATE_SIZE, action_size=action_size)
# Run a few test episodes
test_episodes = 3
logger.info(f"Running {test_episodes} test episodes...")
for episode in range(test_episodes):
state = env.reset()
episode_reward = 0
done = False
step = 0
while not done and step < 100: # Limit steps for testing
# Select action
action = agent.select_action(state)
# Take action
next_state, reward, done, info = env.step(action)
# Store experience
agent.memory.push(state, action, reward, next_state, done)
# Learn
loss = agent.learn()
state = next_state
episode_reward += reward
step += 1
# Print progress
if step % 10 == 0:
logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")
logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")
# Test model saving
try:
agent.save("models/test_model.pt")
logger.info("Successfully saved model")
except Exception as e:
logger.error(f"Error saving model: {e}")
logger.info("Training tests completed successfully")
return True
except Exception as e:
logger.error(f"Training test failed: {e}")
return False
    finally:
        # Standard (non-async) ccxt clients expose a synchronous close() or
        # none at all, so only await when we actually get a coroutine back
        close = getattr(exchange, 'close', None)
        if close:
            result = close()
            if asyncio.iscoroutine(result):
                await result
async def initialize_exchange():
"""Initialize the exchange connection"""
try:
# Try to initialize with async support first
try:
exchange = ccxt.pro.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with async support: {exchange.id}")
except (AttributeError, ImportError):
# Fall back to standard CCXT
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
return exchange
except Exception as e:
logger.error(f"Failed to initialize exchange: {e}")
raise
async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""Fetch historical OHLCV data from the exchange"""
try:
logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")
# Use the refactored fetch method
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
if not data:
logger.warning("No historical data received")
return data
except Exception as e:
logger.error(f"Failed to fetch historical data: {e}")
return []
async def live_trading(
symbol="ETH/USDT",
timeframe="1m",
model_path=None,
demo=False,
leverage=50,
initial_balance=1000,
max_position_size=0.1,
commission=0.0004,
window_size=30,
update_interval=60,
stop_loss_pct=0.02,
take_profit_pct=0.04,
max_trades_per_day=10,
risk_per_trade=0.02,
use_trailing_stop=False,
trailing_stop_callback=0.005,
use_dynamic_sizing=True,
use_volatility_sizing=True,
use_multi_timeframe=True,
use_sentiment=False,
use_limit_orders=False,
use_dollar_cost_avg=False,
use_grid_trading=False,
use_martingale=False,
use_anti_martingale=False,
use_custom_indicators=True,
use_ml_predictions=True,
use_ensemble=True,
use_reinforcement=True,
use_risk_management=True,
use_portfolio_management=False,
use_position_sizing=True,
use_stop_loss=True,
use_take_profit=True,
use_trailing_stop_loss=False,
use_dynamic_stop_loss=True,
use_dynamic_take_profit=True,
use_dynamic_trailing_stop=False,
use_dynamic_position_sizing=True,
use_dynamic_leverage=False,
use_dynamic_risk_per_trade=True,
use_dynamic_max_trades_per_day=False,
use_dynamic_update_interval=False,
use_dynamic_window_size=False,
use_dynamic_commission=False,
use_dynamic_timeframe=False,
use_dynamic_symbol=False,
use_dynamic_model_path=False,
use_dynamic_demo=False,
use_dynamic_leverage_value=False,
use_dynamic_initial_balance=False,
use_dynamic_max_position_size=False,
use_dynamic_stop_loss_pct=False,
use_dynamic_take_profit_pct=False,
use_dynamic_risk_per_trade_value=False,
use_dynamic_trailing_stop_callback=False,
use_dynamic_use_trailing_stop=False,
use_dynamic_use_dynamic_sizing=False,
use_dynamic_use_volatility_sizing=False,
use_dynamic_use_multi_timeframe=False,
use_dynamic_use_sentiment=False,
use_dynamic_use_limit_orders=False,
use_dynamic_use_dollar_cost_avg=False,
use_dynamic_use_grid_trading=False,
use_dynamic_use_martingale=False,
use_dynamic_use_anti_martingale=False,
use_dynamic_use_custom_indicators=False,
use_dynamic_use_ml_predictions=False,
use_dynamic_use_ensemble=False,
use_dynamic_use_reinforcement=False,
use_dynamic_use_risk_management=False,
use_dynamic_use_portfolio_management=False,
use_dynamic_use_position_sizing=False,
use_dynamic_use_stop_loss=False,
use_dynamic_use_take_profit=False,
use_dynamic_use_trailing_stop_loss=False,
use_dynamic_use_dynamic_stop_loss=False,
use_dynamic_use_dynamic_take_profit=False,
use_dynamic_use_dynamic_trailing_stop=False,
use_dynamic_use_dynamic_position_sizing=False,
use_dynamic_use_dynamic_leverage=False,
use_dynamic_use_dynamic_risk_per_trade=False,
use_dynamic_use_dynamic_max_trades_per_day=False,
use_dynamic_use_dynamic_update_interval=False,
use_dynamic_use_dynamic_window_size=False,
use_dynamic_use_dynamic_commission=False,
use_dynamic_use_dynamic_timeframe=False,
use_dynamic_use_dynamic_symbol=False,
use_dynamic_use_dynamic_model_path=False,
use_dynamic_use_dynamic_demo=False,
use_dynamic_use_dynamic_leverage_value=False,
use_dynamic_use_dynamic_initial_balance=False,
use_dynamic_use_dynamic_max_position_size=False,
use_dynamic_use_dynamic_stop_loss_pct=False,
use_dynamic_use_dynamic_take_profit_pct=False,
use_dynamic_use_dynamic_risk_per_trade_value=False,
use_dynamic_use_dynamic_trailing_stop_callback=False,
):
"""
Live trading function that connects to the exchange and trades in real-time.
Args:
symbol: Trading pair symbol
timeframe: Timeframe for trading
model_path: Path to the trained model
demo: Whether to use demo mode (sandbox)
leverage: Leverage to use
initial_balance: Initial balance
max_position_size: Maximum position size as a percentage of balance
commission: Commission rate
window_size: Window size for the environment
update_interval: Interval to update data in seconds
stop_loss_pct: Stop loss percentage
take_profit_pct: Take profit percentage
max_trades_per_day: Maximum trades per day
risk_per_trade: Risk per trade as a percentage of balance
use_trailing_stop: Whether to use trailing stop
trailing_stop_callback: Trailing stop callback percentage
use_dynamic_sizing: Whether to use dynamic position sizing
use_volatility_sizing: Whether to use volatility-based position sizing
use_multi_timeframe: Whether to use multi-timeframe analysis
use_sentiment: Whether to use sentiment analysis
use_limit_orders: Whether to use limit orders
use_dollar_cost_avg: Whether to use dollar cost averaging
use_grid_trading: Whether to use grid trading
use_martingale: Whether to use martingale strategy
use_anti_martingale: Whether to use anti-martingale strategy
use_custom_indicators: Whether to use custom indicators
use_ml_predictions: Whether to use ML predictions
use_ensemble: Whether to use ensemble methods
use_reinforcement: Whether to use reinforcement learning
use_risk_management: Whether to use risk management
use_portfolio_management: Whether to use portfolio management
use_position_sizing: Whether to use position sizing
use_stop_loss: Whether to use stop loss
use_take_profit: Whether to use take profit
use_trailing_stop_loss: Whether to use trailing stop loss
use_dynamic_stop_loss: Whether to use dynamic stop loss
use_dynamic_take_profit: Whether to use dynamic take profit
use_dynamic_trailing_stop: Whether to use dynamic trailing stop
use_dynamic_position_sizing: Whether to use dynamic position sizing
use_dynamic_leverage: Whether to use dynamic leverage
use_dynamic_risk_per_trade: Whether to use dynamic risk per trade
use_dynamic_max_trades_per_day: Whether to use dynamic max trades per day
use_dynamic_update_interval: Whether to use dynamic update interval
use_dynamic_window_size: Whether to use dynamic window size
use_dynamic_commission: Whether to use dynamic commission
use_dynamic_timeframe: Whether to use dynamic timeframe
use_dynamic_symbol: Whether to use dynamic symbol
use_dynamic_model_path: Whether to use dynamic model path
use_dynamic_demo: Whether to use dynamic demo
use_dynamic_leverage_value: Whether to use dynamic leverage value
use_dynamic_initial_balance: Whether to use dynamic initial balance
use_dynamic_max_position_size: Whether to use dynamic max position size
use_dynamic_stop_loss_pct: Whether to use dynamic stop loss percentage
use_dynamic_take_profit_pct: Whether to use dynamic take profit percentage
use_dynamic_risk_per_trade_value: Whether to use dynamic risk per trade value
use_dynamic_trailing_stop_callback: Whether to use dynamic trailing stop callback
"""
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
logger.info(f"Demo mode: {demo}, Leverage: {leverage}x")
# Flag to track if we're using mock trading
using_mock_trading = False
# Initialize exchange
try:
exchange = await initialize_exchange()
# Try to set sandbox mode if demo is True
if demo:
try:
exchange.set_sandbox_mode(demo)
logger.info(f"Sandbox mode set to {demo}")
except Exception as e:
logger.warning(f"Exchange doesn't support sandbox mode: {e}")
logger.info("Continuing in mock trading mode instead")
using_mock_trading = True
# Set leverage
if not demo or using_mock_trading:
try:
await exchange.set_leverage(leverage, symbol)
logger.info(f"Leverage set to {leverage}x")
except Exception as e:
logger.warning(f"Failed to set leverage: {e}")
# Initialize environment
env = TradingEnvironment(
initial_balance=initial_balance,
leverage=leverage,
window_size=window_size,
commission=commission,
symbol=symbol,
timeframe=timeframe,
max_position_size=max_position_size,
stop_loss_pct=stop_loss_pct,
take_profit_pct=take_profit_pct,
max_trades_per_day=max_trades_per_day,
risk_per_trade=risk_per_trade,
use_trailing_stop=use_trailing_stop,
trailing_stop_callback=trailing_stop_callback,
use_dynamic_sizing=use_dynamic_sizing,
use_volatility_sizing=use_volatility_sizing,
use_multi_timeframe=use_multi_timeframe,
use_sentiment=use_sentiment,
use_limit_orders=use_limit_orders,
use_dollar_cost_avg=use_dollar_cost_avg,
use_grid_trading=use_grid_trading,
use_martingale=use_martingale,
use_anti_martingale=use_anti_martingale,
use_custom_indicators=use_custom_indicators,
use_ml_predictions=use_ml_predictions,
use_ensemble=use_ensemble,
use_reinforcement=use_reinforcement,
use_risk_management=use_risk_management,
use_portfolio_management=use_portfolio_management,
use_position_sizing=use_position_sizing,
use_stop_loss=use_stop_loss,
use_take_profit=use_take_profit,
use_trailing_stop_loss=use_trailing_stop_loss,
use_dynamic_stop_loss=use_dynamic_stop_loss,
use_dynamic_take_profit=use_dynamic_take_profit,
use_dynamic_trailing_stop=use_dynamic_trailing_stop,
use_dynamic_position_sizing=use_dynamic_position_sizing,
use_dynamic_leverage=use_dynamic_leverage,
use_dynamic_risk_per_trade=use_dynamic_risk_per_trade,
use_dynamic_max_trades_per_day=use_dynamic_max_trades_per_day,
use_dynamic_update_interval=use_dynamic_update_interval,
use_dynamic_window_size=use_dynamic_window_size,
use_dynamic_commission=use_dynamic_commission,
use_dynamic_timeframe=use_dynamic_timeframe,
use_dynamic_symbol=use_dynamic_symbol,
use_dynamic_model_path=use_dynamic_model_path,
use_dynamic_demo=use_dynamic_demo,
use_dynamic_leverage_value=use_dynamic_leverage_value,
use_dynamic_initial_balance=use_dynamic_initial_balance,
use_dynamic_max_position_size=use_dynamic_max_position_size,
use_dynamic_stop_loss_pct=use_dynamic_stop_loss_pct,
use_dynamic_take_profit_pct=use_dynamic_take_profit_pct,
use_dynamic_risk_per_trade_value=use_dynamic_risk_per_trade_value,
use_dynamic_trailing_stop_callback=use_dynamic_trailing_stop_callback,
use_dynamic_use_trailing_stop=use_dynamic_use_trailing_stop,
use_dynamic_use_dynamic_sizing=use_dynamic_use_dynamic_sizing,
use_dynamic_use_volatility_sizing=use_dynamic_use_volatility_sizing,
use_dynamic_use_multi_timeframe=use_dynamic_use_multi_timeframe,
use_dynamic_use_sentiment=use_dynamic_use_sentiment,
use_dynamic_use_limit_orders=use_dynamic_use_limit_orders,
use_dynamic_use_dollar_cost_avg=use_dynamic_use_dollar_cost_avg,
use_dynamic_use_grid_trading=use_dynamic_use_grid_trading,
use_dynamic_use_martingale=use_dynamic_use_martingale,
use_dynamic_use_anti_martingale=use_dynamic_use_anti_martingale,
use_dynamic_use_custom_indicators=use_dynamic_use_custom_indicators,
use_dynamic_use_ml_predictions=use_dynamic_use_ml_predictions,
use_dynamic_use_ensemble=use_dynamic_use_ensemble,
use_dynamic_use_reinforcement=use_dynamic_use_reinforcement,
use_dynamic_use_risk_management=use_dynamic_use_risk_management,
use_dynamic_use_portfolio_management=use_dynamic_use_portfolio_management,
use_dynamic_use_position_sizing=use_dynamic_use_position_sizing,
use_dynamic_use_stop_loss=use_dynamic_use_stop_loss,
use_dynamic_use_take_profit=use_dynamic_use_take_profit,
use_dynamic_use_trailing_stop_loss=use_dynamic_use_trailing_stop_loss,
use_dynamic_use_dynamic_stop_loss=use_dynamic_use_dynamic_stop_loss,
use_dynamic_use_dynamic_take_profit=use_dynamic_use_dynamic_take_profit,
use_dynamic_use_dynamic_trailing_stop=use_dynamic_use_dynamic_trailing_stop,
use_dynamic_use_dynamic_position_sizing=use_dynamic_use_dynamic_position_sizing,
use_dynamic_use_dynamic_leverage=use_dynamic_use_dynamic_leverage,
use_dynamic_use_dynamic_risk_per_trade=use_dynamic_use_dynamic_risk_per_trade,
use_dynamic_use_dynamic_max_trades_per_day=use_dynamic_use_dynamic_max_trades_per_day,
use_dynamic_use_dynamic_update_interval=use_dynamic_use_dynamic_update_interval,
use_dynamic_use_dynamic_window_size=use_dynamic_use_dynamic_window_size,
use_dynamic_use_dynamic_commission=use_dynamic_use_dynamic_commission,
use_dynamic_use_dynamic_timeframe=use_dynamic_use_dynamic_timeframe,
use_dynamic_use_dynamic_symbol=use_dynamic_use_dynamic_symbol,
use_dynamic_use_dynamic_model_path=use_dynamic_use_dynamic_model_path,
use_dynamic_use_dynamic_demo=use_dynamic_use_dynamic_demo,
use_dynamic_use_dynamic_leverage_value=use_dynamic_use_dynamic_leverage_value,
use_dynamic_use_dynamic_initial_balance=use_dynamic_use_dynamic_initial_balance,
use_dynamic_use_dynamic_max_position_size=use_dynamic_use_dynamic_max_position_size,
use_dynamic_use_dynamic_stop_loss_pct=use_dynamic_use_dynamic_stop_loss_pct,
use_dynamic_use_dynamic_take_profit_pct=use_dynamic_use_dynamic_take_profit_pct,
use_dynamic_use_dynamic_risk_per_trade_value=use_dynamic_use_dynamic_risk_per_trade_value,
use_dynamic_use_dynamic_trailing_stop_callback=use_dynamic_use_dynamic_trailing_stop_callback,
)
# Fetch initial data
logger.info(f"Fetching initial data for {symbol}")
await fetch_and_update_data(exchange, env, symbol, timeframe)
        # Initialize agent (lowercase locals to avoid shadowing the module-level STATE_SIZE constant)
        state_size = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
        action_size = env.action_space.n if hasattr(env.action_space, 'n') else 4
        agent = Agent(state_size=state_size, action_size=action_size, hidden_size=384)
# Load model if provided
if model_path:
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
# Initialize TensorBoard writer
agent.writer = SummaryWriter(log_dir=f"runs/live_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
# Initialize trading statistics
trades = []
total_pnl = 0
win_count = 0
loss_count = 0
# Initialize trading log file
log_file = f"live_trading_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
with open(log_file, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
# Start live trading loop
logger.info(f"Starting live trading with {symbol} on {timeframe} timeframe")
# Main trading loop
step_counter = 0
last_update_time = time.time()
while True:
# Get current state
state = env.get_state()
# Select action
action = agent.select_action(state, training=False)
# Take action
next_state, reward, done, info = env.step(action)
# Log action and results
if info.get('trade_executed', False):
trade_data = {
'timestamp': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'action': info['action'],
'price': env.current_price,
'position_size': env.position_size,
'balance': env.balance,
'pnl': env.last_trade_profit
}
trades.append(trade_data)
# Update statistics
if env.last_trade_profit > 0:
win_count += 1
total_pnl += env.last_trade_profit
else:
loss_count += 1
# Log trade to file
with open(log_file, 'a') as f:
f.write(f"{trade_data['timestamp']},{trade_data['action']},{trade_data['price']},{trade_data['position_size']},{trade_data['balance']},{trade_data['pnl']}\n")
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
# Update TensorBoard metrics
if step_counter % 10 == 0:
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
agent.writer.add_scalar('Live/Reward', reward, step_counter)
# Check if it's time to update data
current_time = time.time()
if current_time - last_update_time > update_interval:
await fetch_and_update_data(exchange, env, symbol, timeframe)
last_update_time = current_time
# Print status update
win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0
logger.info(f"""
Step: {step_counter}
Balance: ${env.balance:.2f}
Total PnL: ${env.total_pnl:.2f}
Win Rate: {win_rate:.2f}
Trades: {len(trades)}
""")
# Move to next state
state = next_state
step_counter += 1
# Sleep to avoid excessive API calls
await asyncio.sleep(1)
# Check for manual stop
if done:
break
# Close TensorBoard writer
agent.writer.close()
# Save final statistics
win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0
logger.info(f"""
Live Trading Summary:
Total Steps: {step_counter}
Final Balance: ${env.balance:.2f}
Total PnL: ${env.total_pnl:.2f}
Win Rate: {win_rate:.2f}
Total Trades: {len(trades)}
""")
# Close exchange connection
try:
await exchange.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.warning(f"Error closing exchange connection: {e}")
except Exception as e:
logger.error(f"Error in live trading: {e}")
logger.error(traceback.format_exc())
        try:
            if 'exchange' in locals() and exchange:
                await exchange.close()
                logger.info("Exchange connection closed")
        except Exception:
            pass
async def get_latest_candle(exchange, symbol):
"""
Get the latest candle for a symbol.
Args:
exchange: Exchange instance
symbol: Trading pair symbol
Returns:
Latest candle data or None on failure
"""
try:
# Use the refactored fetch method with limit=1
data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)
if data and len(data) > 0:
return data[0]
else:
logger.warning("No candle data received")
return None
except Exception as e:
logger.error(f"Failed to fetch latest candle: {e}")
return None
async def fetch_ohlcv_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""
Fetch OHLCV data from exchange with error handling and retry logic.
Args:
exchange: The exchange instance
symbol: Trading pair symbol
timeframe: Candle timeframe
limit: Number of candles to fetch
Returns:
List of candle data or empty list on failure
"""
max_retries = 3
retry_delay = 5
for attempt in range(max_retries):
try:
logging.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt {attempt+1}/{max_retries})")
# Check if exchange has fetch_ohlcv method
if not hasattr(exchange, 'fetch_ohlcv'):
logging.error("Exchange does not support OHLCV data fetching")
return []
# Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor
try:
if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False):
ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, limit=limit)
else:
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
ohlcv = await loop.run_in_executor(
None,
lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
)
except Exception as e:
logging.error(f"Failed to fetch OHLCV data: {e}")
await asyncio.sleep(retry_delay)
continue
if not ohlcv or len(ohlcv) == 0:
logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})")
await asyncio.sleep(retry_delay)
continue
# Convert to list of lists format
data = []
for candle in ohlcv:
timestamp, open_price, high, low, close, volume = candle
data.append([timestamp, open_price, high, low, close, volume])
logging.info(f"Successfully fetched {len(data)} candles")
return data
except Exception as e:
logging.error(f"Error fetching OHLCV data (attempt {attempt+1}/{max_retries}): {e}")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
logging.error(f"Failed to fetch OHLCV data after {max_retries} attempts")
return []
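# Illustrative usage of the fetcher above (an assumption, not called anywhere
# in this module); run it inside an event loop:
async def _fetch_ohlcv_sketch():
    exchange = await initialize_exchange()
    try:
        candles = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", limit=100)
        logger.info(f"Fetched {len(candles)} candles")
    finally:
        if hasattr(exchange, 'close'):
            result = exchange.close()
            if asyncio.iscoroutine(result):
                await result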
# Compatibility helper for loading models under PyTorch 2.6+ serialization rules
def ensure_pytorch_compatibility():
"""Ensure compatibility with PyTorch 2.6+ for model loading"""
try:
import torch
from torch.serialization import add_safe_globals
import numpy as np
# Add numpy scalar to safe globals for PyTorch 2.6+
add_safe_globals(['numpy._core.multiarray.scalar'])
logger.info("Added numpy scalar to PyTorch safe globals")
except (ImportError, AttributeError) as e:
logger.warning(f"Could not configure PyTorch compatibility: {e}")
logger.warning("This might cause issues with model loading in PyTorch 2.6+")
# ensure_pytorch_compatibility() is invoked at the start of main() below
async def main():
# Ensure PyTorch compatibility
ensure_pytorch_compatibility()
parser = argparse.ArgumentParser(description='Trading Bot')
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
help='Operation mode: train, eval, or live')
parser.add_argument('--episodes', type=int, default=1000,
help='Number of episodes for training or evaluation')
parser.add_argument('--max_steps', type=int, default=1000,
help='Maximum steps per episode for training')
parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
help='Run in demo mode (paper trading) if true')
parser.add_argument('--symbol', type=str, default='ETH/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m',
help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
parser.add_argument('--leverage', type=int, default=50,
help='Leverage for futures trading')
parser.add_argument('--model', type=str, default='models/trading_agent_best_net_pnl.pt',
help='Path to model file for evaluation or live trading')
parser.add_argument('--compact_save', action='store_true',
help='Use compact model saving (for low disk space)')
args = parser.parse_args()
# Convert string boolean to actual boolean
demo_mode = args.demo.lower() == 'true'
# Get device (GPU or CPU)
device = get_device()
exchange = None
try:
# Initialize exchange
exchange = await initialize_exchange()
# Create environment with updated parameters
env = TradingEnvironment(
initial_balance=INITIAL_BALANCE,
window_size=30,
leverage=args.leverage,
exchange_id='mexc',
symbol=args.symbol,
timeframe=args.timeframe
)
if args.mode == 'train':
# Fetch initial data for training
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Create agent with consistent parameters
# Note: Using STATE_SIZE and action_size=4 for consistency
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Train the agent
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes,
max_steps_per_episode=args.max_steps,
use_compact_save=args.compact_save)
elif args.mode == 'eval' or args.mode == 'live':
# Fetch initial data for the specified symbol and timeframe
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Determine model path
model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
# Create agent with default parameters
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Try to load the model
try:
# Add numpy scalar to safe globals before loading
import numpy as np
from torch.serialization import add_safe_globals
# Add numpy scalar to safe globals
add_safe_globals(['numpy._core.multiarray.scalar'])
# Load the model
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
# Ask user if they want to continue with a new model
if args.mode == 'live':
confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
if confirmation.lower() != 'y':
logger.info("Live trading canceled by user")
return
logger.info("Continuing with a new model")
else:
logger.info("Continuing evaluation with a new model")
if args.mode == 'eval':
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'live':
# Start live trading
logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
                # live_trading builds its own exchange, environment, and agent from
                # these parameters (loading the model itself via model_path)
                await live_trading(
                    symbol=args.symbol,
                    timeframe=args.timeframe,
                    model_path=model_path,
                    demo=demo_mode,
                    leverage=args.leverage
                )
except Exception as e:
logger.error(f"Error in main function: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
# Clean up exchange connection
if exchange:
try:
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}")
# Charting utility: render cached OHLCV data and executed trades as a candlestick figure
def create_candlestick_figure(data, trades=None, title="Trading Chart"):
"""Create a candlestick chart with trades marked"""
try:
if data is None or len(data) < 5:
logger.warning("Not enough data for candlestick chart")
return None
# Convert data to DataFrame if it's not already
if not isinstance(data, pd.DataFrame):
df = pd.DataFrame(data)
else:
df = data.copy()
# Ensure required columns exist
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
for col in required_columns:
if col not in df.columns:
logger.warning(f"Missing required column {col} for candlestick chart")
return None
# Format dates
if 'timestamp' in df.columns:
if isinstance(df['timestamp'].iloc[0], (int, float)):
# Convert timestamp to datetime if it's numeric
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
# Set timestamp as index if it's not already
if df.index.name != 'timestamp':
df.set_index('timestamp', inplace=True)
# Rename columns for mplfinance
df_mpf = df.copy()
if 'open' in df_mpf.columns:
df_mpf = df_mpf.rename(columns={
'open': 'Open',
'high': 'High',
'low': 'Low',
'close': 'Close',
'volume': 'Volume'
})
        # Plot candlesticks manually on a simple matplotlib figure
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8),
                                       gridspec_kw={'height_ratios': [3, 1]})
        # Bar width is measured in days on a datetime axis, so derive it from
        # the candle spacing (fall back to 0.6 days for a single candle)
        if len(df_mpf) > 1:
            candle_width = 0.6 * (df_mpf.index[1] - df_mpf.index[0]).total_seconds() / 86400
        else:
            candle_width = 0.6
        for i in range(len(df_mpf)):
            # Get date and prices
            date = df_mpf.index[i]
            open_price = df_mpf['Open'].iloc[i]
            high_price = df_mpf['High'].iloc[i]
            low_price = df_mpf['Low'].iloc[i]
            close_price = df_mpf['Close'].iloc[i]
            # Determine color based on price movement
            color = 'green' if close_price >= open_price else 'red'
            # Plot candle body
            body_height = abs(close_price - open_price)
            body_bottom = min(close_price, open_price)
            ax1.bar(date, body_height, bottom=body_bottom, width=candle_width,
                    color=color, alpha=0.6)
            # Plot wick
            ax1.plot([date, date], [low_price, high_price], color=color, linewidth=1)
        # Plot volume
        ax2.bar(df_mpf.index, df_mpf['Volume'], width=candle_width, color='blue', alpha=0.5)
# Mark trades if available
if trades and len(trades) > 0:
for trade in trades:
if 'timestamp' not in trade or 'type' not in trade or 'price' not in trade:
continue
# Convert timestamp to datetime if needed
if isinstance(trade['timestamp'], (int, float)):
trade_time = pd.to_datetime(trade['timestamp'], unit='ms')
else:
trade_time = trade['timestamp']
# Determine marker color based on trade type
marker_color = 'green' if trade['type'].lower() == 'buy' else 'red'
# Add marker at trade price
ax1.scatter(trade_time, trade['price'], color=marker_color,
marker='^' if trade['type'].lower() == 'buy' else 'v',
s=100, zorder=5)
# Set title and labels
ax1.set_title(title)
ax1.set_ylabel('Price')
ax2.set_ylabel('Volume')
ax1.grid(True)
ax2.grid(True)
# Format x-axis
plt.setp(ax1.get_xticklabels(), visible=False)
# Adjust layout
plt.tight_layout()
return fig
except Exception as e:
logger.error(f"Error creating candlestick figure: {e}")
return None
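# Illustrative usage sketch for create_candlestick_figure. The synthetic rows
# below are an assumption for demonstration only; real data comes from the
# candle cache as [timestamp_ms, open, high, low, close, volume] rows.
def _example_candlestick_figure():
    """Render a chart from synthetic OHLCV data (illustrative only)."""
    now_ms = int(time.time() * 1000)
    rows = [
        {'timestamp': now_ms + i * 60_000, 'open': 100.0 + i, 'high': 101.0 + i,
         'low': 99.0 + i, 'close': 100.5 + i, 'volume': 10.0}
        for i in range(10)
    ]
    trades = [{'timestamp': now_ms + 5 * 60_000, 'type': 'buy', 'price': 105.0}]
    fig = create_candlestick_figure(rows, trades, title="Example Chart")
    if fig is not None:
        fig.savefig("example_chart.png")
        plt.close(fig)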
class CandlePatternCNN(nn.Module):
"""Convolutional neural network for detecting candlestick patterns"""
def __init__(self, input_channels=5, feature_dimension=512):
super(CandlePatternCNN, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.relu3 = nn.ReLU()
self.pool3 = nn.MaxPool2d(kernel_size=2)
        # Projection layers (the 128 * 4 * 4 flatten assumes 32x32 inputs:
        # three 2x2 max-pools take 32 -> 16 -> 8 -> 4)
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
self.relu4 = nn.ReLU()
self.fc2 = nn.Linear(1024, feature_dimension)
# Initialize intermediate features as empty tensors, not as a dict
# This makes the model TorchScript compatible
self.feature_1m = torch.zeros(1, feature_dimension)
self.feature_1h = torch.zeros(1, feature_dimension)
self.feature_1d = torch.zeros(1, feature_dimension)
def forward(self, x_1m, x_1h, x_1d):
# Process 1m data
feat_1m = self.process_timeframe(x_1m)
# Process 1h data
feat_1h = self.process_timeframe(x_1h)
# Process 1d data
feat_1d = self.process_timeframe(x_1d)
# Store features as attributes instead of in a dictionary
self.feature_1m = feat_1m
self.feature_1h = feat_1h
self.feature_1d = feat_1d
# Concatenate features from different timeframes
combined_features = torch.cat([feat_1m, feat_1h, feat_1d], dim=1)
return combined_features
def process_timeframe(self, x):
"""Process a single timeframe batch of data"""
# Ensure proper shape for input, handle both batched and single inputs
if len(x.shape) == 3: # Single input, shape: [channels, height, width]
x = x.unsqueeze(0) # Add batch dimension
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = self.pool3(self.relu3(self.conv3(x)))
# Flatten the spatial dimensions for the fully connected layer
x = x.view(x.size(0), -1)
x = self.relu4(self.fc1(x))
x = self.fc2(x)
return x
def get_features(self):
"""Return features for each timeframe"""
# Use properties instead of dict for TorchScript compatibility
return self.feature_1m, self.feature_1h, self.feature_1d
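# Illustrative shape check for CandlePatternCNN. The 32x32 input size is an
# assumption implied by the 128 * 4 * 4 flatten in fc1 (three 2x2 poolings).
def _example_candle_pattern_cnn():
    """Sanity-check tensor shapes through the pattern CNN (illustrative only)."""
    cnn = CandlePatternCNN(input_channels=5, feature_dimension=512)
    x = torch.zeros(2, 5, 32, 32)  # [batch, channels, height, width]
    combined = cnn(x, x, x)        # same dummy tensor stands in for 1m/1h/1d
    assert combined.shape == (2, 3 * 512)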
class CandleCache:
"""
Cache system for candles of different timeframes.
Reduces API calls by storing and updating candle data.
"""
def __init__(self):
self.candles = {
'1m': [],
'1h': [],
'1d': []
}
self.last_updated = {
'1m': None,
'1h': None,
'1d': None
}
# Add ticks channel for real-time data (WebSocket)
self.ticks = []
self.last_tick_time = None
def add_candles(self, timeframe, new_candles):
"""Add new candles to the cache"""
if not self.candles[timeframe]:
self.candles[timeframe] = new_candles
else:
# Find the last timestamp in our current cache
last_timestamp = self.candles[timeframe][-1][0]
# Add only candles newer than our last cached one
for candle in new_candles:
if candle[0] > last_timestamp:
self.candles[timeframe].append(candle)
self.last_updated[timeframe] = datetime.datetime.now()
def add_tick(self, tick_data):
"""Add a new tick to the ticks buffer"""
self.ticks.append(tick_data)
self.last_tick_time = datetime.datetime.now()
# Keep only the most recent 1000 ticks to prevent memory issues
if len(self.ticks) > 1000:
self.ticks = self.ticks[-1000:]
def get_ticks(self, limit=None):
"""Get the most recent ticks from the buffer"""
if not self.ticks:
return []
if limit and limit > 0:
return self.ticks[-limit:]
return self.ticks
def get_candles(self, timeframe, limit=300):
"""Get the most recent candles for a timeframe"""
if not self.candles[timeframe]:
return []
return self.candles[timeframe][-limit:]
def needs_update(self, timeframe, max_age_seconds):
"""Check if the cache needs to be updated"""
if not self.last_updated[timeframe]:
return True
age = (datetime.datetime.now() - self.last_updated[timeframe]).total_seconds()
return age > max_age_seconds
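# Illustrative usage sketch for CandleCache: candle rows follow the ccxt OHLCV
# layout [timestamp_ms, open, high, low, close, volume] used elsewhere in this
# file; the values below are made up.
def _example_candle_cache():
    """Show that add_candles dedupes overlapping fetches (illustrative only)."""
    cache = CandleCache()
    cache.add_candles('1m', [[1_000, 1.0, 2.0, 0.5, 1.5, 10.0]])
    # Overlapping fetch: only the candle with the newer timestamp is appended
    cache.add_candles('1m', [[1_000, 1.0, 2.0, 0.5, 1.5, 10.0],
                             [61_000, 1.5, 2.5, 1.0, 2.0, 12.0]])
    assert len(cache.get_candles('1m')) == 2
    assert cache.needs_update('1h', max_age_seconds=3600)  # never fetched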
async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
"""Fetch candle data for multiple timeframes, using cache when possible"""
update_intervals = {
'1m': 60, # Update every 1 minute
'1h': 3600, # Update every 1 hour
'1d': 86400 # Update every 1 day
}
# TODO: For 1s/tick timeframes, we'll implement the exchange's WebSocket API
# for real-time data streaming in the future. This will enable ultra-low latency
# trading signals with minimal delay between market data reception and action execution.
# A WebSocket implementation is already prepared in the RealTimeDataStream class.
limits = {
'1m': 1000,
'1h': 500,
'1d': 300
}
for timeframe, interval in update_intervals.items():
if candle_cache.needs_update(timeframe, interval):
try:
logging.info(f"Fetching {timeframe} candle data for {symbol}")
candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limits[timeframe])
candle_cache.add_candles(timeframe, candles)
logging.info(f"Fetched {len(candles)} {timeframe} candles")
except Exception as e:
logging.error(f"Error fetching {timeframe} candle data: {e}")
return {
'1m': candle_cache.get_candles('1m'),
'1h': candle_cache.get_candles('1h'),
'1d': candle_cache.get_candles('1d')
}
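# Illustrative call pattern for the cache-aware fetcher above; assumes
# initialize_exchange() (defined elsewhere in this file) returns an async
# ccxt-style exchange client.
async def _example_fetch_multi_timeframe():
    """One polling iteration over all cached timeframes (illustrative only)."""
    exchange = await initialize_exchange()
    cache = CandleCache()
    try:
        data = await fetch_multi_timeframe_data(exchange, 'ETH/USDT', cache)
        logging.info(f"Fetched {len(data['1m'])} 1m candles via cache")
    finally:
        if hasattr(exchange, 'close'):
            await exchange.close()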
# LSTM-attention DQN with a dueling head; optionally fuses CNN features from multiple timeframes
class LSTMAttentionDQN(nn.Module):
def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
super(LSTMAttentionDQN, self).__init__()
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size
self.lstm_layers = lstm_layers
self.attention_heads = attention_heads
# LSTM layer
self.lstm = nn.LSTM(
input_size=state_size,
hidden_size=hidden_size,
num_layers=lstm_layers,
batch_first=True,
dropout=0.2 if lstm_layers > 1 else 0
)
# Multi-head self-attention
self.attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=attention_heads,
dropout=0.1
)
# Value stream
self.value_stream = nn.Sequential(
nn.Linear(hidden_size, 128),
nn.ReLU(),
nn.Linear(128, 1)
)
# Advantage stream
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_size, 128),
nn.ReLU(),
nn.Linear(128, action_size)
)
        # Fusion for multi-timeframe CNN features; the output is projected to
        # state_size so the fused vector can be appended to the state sequence
        # as one extra LSTM timestep (the LSTM input size is state_size)
        self.cnn_fusion = nn.Sequential(
            nn.Linear(512 * 3, 1024),  # 512 features from each of the 3 timeframes
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, state_size)
        )
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LSTM):
for name, param in module.named_parameters():
if 'weight' in name:
nn.init.xavier_uniform_(param)
elif 'bias' in name:
nn.init.constant_(param, 0)
def forward(self, state, x_1m=None, x_1h=None, x_1d=None):
"""
Forward pass handling different input shapes and optional CNN features
Args:
state: Primary state vector (batch_size, sequence_length, state_size)
x_1m, x_1h, x_1d: Optional CNN features from different timeframes
Returns:
Q-values for each action
"""
batch_size = state.size(0)
        # Handle CNN features if provided
        if x_1m is not None and x_1h is not None and x_1d is not None:
            # Ensure all CNN features have a batch dimension
            # (a 1D tensor is a single sample's feature vector)
            if x_1m.dim() == 1:
                x_1m = x_1m.unsqueeze(0)
            if x_1h.dim() == 1:
                x_1h = x_1h.unsqueeze(0)
            if x_1d.dim() == 1:
                x_1d = x_1d.unsqueeze(0)
            # Ensure batch dimensions match the state batch
            if x_1m.size(0) != batch_size:
                x_1m = x_1m.expand(batch_size, *x_1m.shape[1:]) if x_1m.size(0) == 1 else x_1m[:batch_size]
            if x_1h.size(0) != batch_size:
                x_1h = x_1h.expand(batch_size, *x_1h.shape[1:]) if x_1h.size(0) == 1 else x_1h[:batch_size]
            if x_1d.size(0) != batch_size:
                x_1d = x_1d.expand(batch_size, *x_1d.shape[1:]) if x_1d.size(0) == 1 else x_1d[:batch_size]
            # The fusion layer expects a flat [batch, 512] vector per timeframe
            if x_1m.dim() == 2 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512:
                # Already in correct format [batch, features]
                cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1)
            else:
                # Flatten, then pad or truncate each timeframe to 512 features
                needed_features = 512
                flats = []
                for feat in (x_1m, x_1h, x_1d):
                    flat = feat.reshape(batch_size, -1)
                    if flat.size(1) < needed_features:
                        flat = F.pad(flat, (0, needed_features - flat.size(1)))
                    else:
                        flat = flat[:, :needed_features]
                    flats.append(flat)
                cnn_combined = torch.cat(flats, dim=1)
            # Project fused CNN features down to state_size so they can be
            # appended as one extra timestep of the LSTM input
            cnn_features = self.cnn_fusion(cnn_combined)
            cnn_features = cnn_features.view(batch_size, 1, self.state_size)
# Combine with state input by concatenating along sequence dimension
if state.dim() < 3:
# If state is 2D [batch, features], reshape to 3D [batch, 1, features]
state = state.unsqueeze(1)
# Ensure state has proper dimensions
if state.size(2) != self.state_size:
# If state dimension doesn't match, reshape or pad
if state.size(2) > self.state_size:
state = state[:, :, :self.state_size]
else:
state = F.pad(state, (0, self.state_size - state.size(2)))
# Concatenate along sequence dimension
combined_input = torch.cat([state, cnn_features], dim=1)
else:
# Use only state input if CNN features not provided
combined_input = state
if combined_input.dim() < 3:
# If state is 2D [batch, features], reshape to 3D [batch, 1, features]
combined_input = combined_input.unsqueeze(1)
# Ensure state has proper dimensions
if combined_input.size(2) != self.state_size:
# If state dimension doesn't match, reshape or pad
if combined_input.size(2) > self.state_size:
combined_input = combined_input[:, :, :self.state_size]
else:
combined_input = F.pad(combined_input, (0, self.state_size - combined_input.size(2)))
# Pass through LSTM
lstm_out, _ = self.lstm(combined_input)
# Apply self-attention to LSTM output
# Transform to shape required by MultiheadAttention (seq_len, batch, hidden)
attn_input = lstm_out.transpose(0, 1)
attn_output, _ = self.attention(attn_input, attn_input, attn_input)
# Transform back to (batch, seq_len, hidden)
attn_output = attn_output.transpose(0, 1)
# Use last output after attention
attn_out = attn_output[:, -1]
# Value and advantage streams (dueling architecture)
value = self.value_stream(attn_out)
advantage = self.advantage_stream(attn_out)
# Combine value and advantage for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values
def forward_realtime(self, x):
"""Simplified forward pass for realtime inference"""
# Adapt x to the right format if needed
if isinstance(x, np.ndarray):
x = torch.FloatTensor(x)
# Add batch dimension if not present
if x.dim() == 1:
x = x.unsqueeze(0)
# Add sequence dimension if not present
if x.dim() == 2:
x = x.unsqueeze(1)
# Basic forward pass
lstm_out, _ = self.lstm(x)
# Apply attention
attn_input = lstm_out.transpose(0, 1)
attn_output, _ = self.attention(attn_input, attn_input, attn_input)
attn_output = attn_output.transpose(0, 1)
# Get last output after attention
features = attn_output[:, -1]
# Value and advantage streams
value = self.value_stream(features)
advantage = self.advantage_stream(features)
# Combine for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values
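# Illustrative shape check for LSTMAttentionDQN. STATE_SIZE is the module
# constant; action_size=4 matches how the Agent is constructed in main().
def _example_lstm_attention_dqn():
    """Verify dueling Q-value shapes with and without CNN features (illustrative only)."""
    net = LSTMAttentionDQN(STATE_SIZE, 4, hidden_size=384,
                           lstm_layers=2, attention_heads=4)
    state = torch.zeros(2, 1, STATE_SIZE)  # [batch, seq_len, state_size]
    q_values = net(state)
    assert q_values.shape == (2, 4)
    # With one 512-dim CNN feature vector per timeframe
    feats = torch.zeros(2, 512)
    q_values = net(state, feats, feats, feats)
    assert q_values.shape == (2, 4)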
class RealTimeDataStream:
"""
Class for handling WebSocket API connections for ultra-low latency trading signals.
Provides real-time data streaming at 1-second intervals or faster for immediate trading decisions.
"""
def __init__(self, exchange, symbol, callback_fn=None):
"""
Initialize the real-time data stream with WebSocket connection
Args:
exchange: The exchange API client
symbol: Trading pair symbol (e.g. 'ETH/USDT')
callback_fn: Function to call when new data is received
"""
self.exchange = exchange
self.symbol = symbol
self.callback_fn = callback_fn
self.websocket = None
self.connected = False
self.last_tick_time = None
self.tick_buffer = []
self.latency_stats = []
self.logger = logging.getLogger(__name__)
# Statistics for monitoring performance
self.total_ticks = 0
self.avg_latency_ms = 0
self.max_latency_ms = 0
# Candle cache for storing processed data
self.candle_cache = CandleCache()
async def connect(self):
"""Connect to the exchange WebSocket API"""
# TODO: Implement actual WebSocket connection logic
self.logger.info(f"Connecting to WebSocket for {self.symbol}...")
try:
# This will be replaced with actual WebSocket connection code
self.websocket = None # Placeholder
self.connected = True
self.logger.info(f"Connected to WebSocket for {self.symbol}")
return True
except Exception as e:
self.logger.error(f"WebSocket connection error: {e}")
return False
async def subscribe(self):
"""Subscribe to relevant data channels"""
# TODO: Implement actual WebSocket subscription logic
self.logger.info(f"Subscribing to {self.symbol} ticks...")
try:
# This will be replaced with actual subscription code
return True
except Exception as e:
self.logger.error(f"WebSocket subscription error: {e}")
return False
async def process_message(self, message):
"""
Process incoming WebSocket message
Args:
message: The raw WebSocket message
Returns:
Processed tick data
"""
# TODO: Implement actual WebSocket message processing logic
try:
# Track tick receipt time for latency calculations
receive_time = time.time() * 1000 # milliseconds
# This is a placeholder - actual implementation will parse the message
# Example tick data structure (will vary by exchange):
tick_data = {
'timestamp': receive_time,
'price': 0.0, # Will be replaced with actual price
'volume': 0.0, # Will be replaced with actual volume
'side': 'buy', # or 'sell'
'exchange_time': 0, # Will be replaced with exchange timestamp
'latency_ms': 0 # Will be calculated
}
# Calculate latency (difference between our receive time and exchange time)
if 'exchange_time' in tick_data and tick_data['exchange_time'] > 0:
latency = receive_time - tick_data['exchange_time']
tick_data['latency_ms'] = latency
# Update latency statistics
self.latency_stats.append(latency)
if len(self.latency_stats) > 1000:
self.latency_stats = self.latency_stats[-1000:]
self.total_ticks += 1
self.avg_latency_ms = sum(self.latency_stats) / len(self.latency_stats)
self.max_latency_ms = max(self.max_latency_ms, latency)
# Store tick in buffer
self.tick_buffer.append(tick_data)
self.candle_cache.add_tick(tick_data)
self.last_tick_time = datetime.datetime.now()
# Keep buffer size reasonable
if len(self.tick_buffer) > 1000:
self.tick_buffer = self.tick_buffer[-1000:]
# Call callback function if provided
if self.callback_fn:
await self.callback_fn(tick_data)
return tick_data
except Exception as e:
self.logger.error(f"Error processing WebSocket message: {e}")
return None
def prepare_nn_input(self, model=None, state=None):
"""
Prepare network inputs from tick data for real-time inference
Args:
model: The neural network model
state: Current state representation
Returns:
Prepared tensors for model input
"""
# Get the most recent ticks
ticks = self.candle_cache.get_ticks(limit=300)
if not ticks or len(ticks) < 10:
# Not enough ticks for meaningful processing
return None
try:
# Extract price and volume data from ticks
prices = np.array([t['price'] for t in ticks if 'price' in t])
volumes = np.array([t['volume'] for t in ticks if 'volume' in t])
if len(prices) < 10:
return None
# Normalize data
min_price, max_price = prices.min(), prices.max()
price_range = max_price - min_price
if price_range == 0:
price_range = 1
normalized_prices = (prices - min_price) / price_range
# Create tick tensor - this is flexible-length data
# Format as sequence for time-series analysis
tick_data = torch.FloatTensor(normalized_prices).unsqueeze(0).unsqueeze(0)
return {
'state': state,
'ticks': tick_data
}
except Exception as e:
self.logger.error(f"Error preparing neural network input: {e}")
return None
def get_latency_stats(self):
"""Get statistics about WebSocket connection latency"""
return {
'total_ticks': self.total_ticks,
'avg_latency_ms': self.avg_latency_ms,
'max_latency_ms': self.max_latency_ms,
'last_update': self.last_tick_time.isoformat() if self.last_tick_time else None
}
async def close(self):
"""Close the WebSocket connection"""
if self.connected and self.websocket:
try:
# This will be replaced with actual close logic
self.connected = False
self.logger.info(f"Closed WebSocket connection for {self.symbol}")
return True
except Exception as e:
self.logger.error(f"Error closing WebSocket connection: {e}")
return False
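# Illustrative wiring of a tick callback into RealTimeDataStream. connect()
# and subscribe() are still placeholders (see the TODOs above), so this only
# demonstrates the intended call sequence.
async def _example_realtime_stream():
    """Attach a logging callback to the tick stream (illustrative only)."""
    async def on_tick(tick):
        logging.info(f"Tick received, latency {tick['latency_ms']:.1f} ms")

    stream = RealTimeDataStream(exchange=None, symbol='ETH/USDT',
                                callback_fn=on_tick)
    if await stream.connect():
        await stream.subscribe()
        logging.info(f"Latency stats: {stream.get_latency_stats()}")
        await stream.close()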
class BacktestCandles(CandleCache):
"""
Special cache for backtesting that retrieves historical data from specific time periods
without contaminating the main cache. Used for running simulations "as if" we were
at a different point in time.
"""
def __init__(self, since_timestamp=None, until_timestamp=None):
"""
Initialize backtesting candle cache.
Args:
since_timestamp: Start timestamp for backtesting (milliseconds)
until_timestamp: End timestamp for backtesting (milliseconds)
"""
super().__init__()
# Since and until timestamps for backtesting
self.since_timestamp = since_timestamp
self.until_timestamp = until_timestamp
# Flag to indicate this is a backtesting cache
self.is_backtesting = True
# Optional name for backtesting period (e.g., "Day 1 - 24h ago")
self.period_name = None
async def fetch_historical_timeframe(self, exchange, symbol, timeframe, limit=1000):
"""
Fetch historical data for a specific timeframe and time period.
Args:
exchange: The exchange instance
symbol: Trading pair symbol
timeframe: Candle timeframe
limit: Number of candles to fetch
Returns:
Dictionary with candle data for the timeframe
"""
try:
logging.info(f"Fetching historical {timeframe} candles for {symbol} " +
f"(since: {self.format_timestamp(self.since_timestamp) if self.since_timestamp else 'None'}, " +
f"until: {self.format_timestamp(self.until_timestamp) if self.until_timestamp else 'None'})")
candles = await self.fetch_ohlcv_with_timerange(exchange, symbol, timeframe,
limit, self.since_timestamp, self.until_timestamp)
if candles:
# Store in the appropriate timeframe
self.candles[timeframe] = candles
self.last_updated[timeframe] = datetime.datetime.now()
logging.info(f"Fetched {len(candles)} historical {timeframe} candles for backtesting")
else:
logging.warning(f"No historical {timeframe} candles found for the specified time period")
return candles
except Exception as e:
logging.error(f"Error fetching historical {timeframe} data: {e}")
return []
async def fetch_all_timeframes(self, exchange, symbol):
"""
Fetch historical data for all timeframes.
Args:
exchange: The exchange instance
symbol: Trading pair symbol
Returns:
Dictionary with candle data for all timeframes
"""
# Define limits for each timeframe
limits = {
'1m': 1000,
'1h': 500,
'1d': 300
}
# Fetch data for each timeframe
for timeframe, limit in limits.items():
await self.fetch_historical_timeframe(exchange, symbol, timeframe, limit)
# Return the candles dictionary
return {
'1m': self.get_candles('1m'),
'1h': self.get_candles('1h'),
'1d': self.get_candles('1d')
}
async def fetch_ohlcv_with_timerange(self, exchange, symbol, timeframe, limit, since=None, until=None):
"""
Fetch OHLCV data within a specific time range.
Args:
exchange: The exchange instance
symbol: Trading pair symbol
timeframe: Candle timeframe
limit: Number of candles to fetch
since: Start timestamp (milliseconds)
until: End timestamp (milliseconds)
Returns:
List of candle data
"""
max_retries = 3
retry_delay = 5
for attempt in range(max_retries):
try:
logging.info(f"Fetching {limit} {timeframe} candles for {symbol} " +
f"(since: {self.format_timestamp(since) if since else 'None'}, " +
f"until: {self.format_timestamp(until) if until else 'None'}) " +
f"(attempt {attempt+1}/{max_retries})")
# Check if exchange has fetch_ohlcv method
if not hasattr(exchange, 'fetch_ohlcv'):
logging.error("Exchange does not support OHLCV data fetching")
return []
# Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor
try:
if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False):
ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, since=since, limit=limit)
else:
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
ohlcv = await loop.run_in_executor(
None,
lambda: exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)
)
except Exception as e:
logging.error(f"Failed to fetch OHLCV data: {e}")
await asyncio.sleep(retry_delay)
continue
if not ohlcv or len(ohlcv) == 0:
logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})")
await asyncio.sleep(retry_delay)
continue
# Filter candles if until timestamp is provided
if until is not None:
ohlcv = [candle for candle in ohlcv if candle[0] <= until]
# Convert to list of lists format
data = []
for candle in ohlcv:
timestamp, open_price, high, low, close, volume = candle
data.append([timestamp, open_price, high, low, close, volume])
logging.info(f"Successfully fetched {len(data)} historical candles")
return data
except Exception as e:
logging.error(f"Error fetching historical OHLCV data (attempt {attempt+1}/{max_retries}): {e}")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
logging.error(f"Failed to fetch historical OHLCV data after {max_retries} attempts")
return []
def format_timestamp(self, timestamp):
"""Format a timestamp for readable logging"""
if timestamp is None:
return "None"
try:
dt = datetime.datetime.fromtimestamp(timestamp / 1000.0)
return dt.strftime('%Y-%m-%d %H:%M:%S')
        except Exception:
return str(timestamp)
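# Illustrative construction of a 24h backtesting window ending now; timestamps
# are epoch milliseconds, matching fetch_ohlcv_with_timerange above.
def _example_backtest_window():
    """Build a one-day historical window for backtesting (illustrative only)."""
    now_ms = int(time.time() * 1000)
    day_ms = 24 * 60 * 60 * 1000
    cache = BacktestCandles(since_timestamp=now_ms - day_ms,
                            until_timestamp=now_ms)
    cache.period_name = "Day 1 - 24h ago"
    logging.info(f"Backtest window: {cache.format_timestamp(cache.since_timestamp)} "
                 f"-> {cache.format_timestamp(cache.until_timestamp)}")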
async def train_with_backtesting(agent, env, symbol="ETH/USDT",
                                since_timestamp=None, until_timestamp=None,
                                num_episodes=10, max_steps_per_episode=1000,
                                period_name=None, use_compact_save=False):
    """
    Train agent with backtesting on historical data.
    Args:
        agent: The agent to train
        env: Trading environment
        symbol: Trading pair symbol
        since_timestamp: Start timestamp for backtesting
        until_timestamp: End timestamp for backtesting
        num_episodes: Number of episodes to train
        max_steps_per_episode: Maximum steps per episode
        period_name: Name of the backtest period
        use_compact_save: Save checkpoints with compact_save to reduce disk usage
    Returns:
        Training statistics dictionary
    """
# Create a backtesting candle cache
backtest_cache = BacktestCandles(since_timestamp, until_timestamp)
if period_name:
backtest_cache.period_name = period_name
logging.info(f"Starting backtesting for period: {period_name}")
# Initialize exchange for data fetching
exchange = None
try:
exchange = await initialize_exchange()
logging.info("Initialized exchange for backtesting")
except Exception as e:
logging.error(f"Failed to initialize exchange: {e}")
return None
# Initialize statistics tracking
stats = {
'period': period_name,
'since_timestamp': since_timestamp,
'until_timestamp': until_timestamp,
'episode_rewards': [],
'episode_lengths': [],
'balances': [],
'win_rates': [],
'episode_pnls': [],
'cumulative_pnl': [],
'drawdowns': [],
'trade_counts': [],
'loss_values': [],
'fees': [],
'net_pnl_after_fees': []
    }
    # Optional TensorBoard writer; kept as None so the `if writer:` logging
    # below is a no-op unless a SummaryWriter is attached
    writer = None
# Memory management function
def clean_memory():
"""Clean up memory to avoid memory leaks"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Fetch historical data for all timeframes
try:
clean_memory() # Clean memory before fetching data
candle_data = await backtest_cache.fetch_all_timeframes(exchange, symbol)
if not candle_data or not candle_data['1m']:
logging.error(f"No historical data available for backtesting period: {period_name}")
try:
await exchange.close()
except Exception as e:
logging.error(f"Error closing exchange: {e}")
return None
logging.info(f"Fetched historical data for backtesting: {len(candle_data['1m'])} minute candles")
except Exception as e:
logging.error(f"Failed to fetch historical data for backtesting: {e}")
try:
await exchange.close()
except Exception as exchange_err:
logging.error(f"Error closing exchange: {exchange_err}")
return None
# Track best models
best_reward = float('-inf')
best_pnl = float('-inf')
best_net_pnl = float('-inf')
# Make directory for backtesting models if it doesn't exist
os.makedirs('models/backtest', exist_ok=True)
# Start backtesting training loop
for episode in range(num_episodes):
try:
# Clean memory before starting a new episode
clean_memory()
# Reset environment
state = env.reset()
episode_reward = 0
episode_losses = []
# Update CNN patterns with historical data
env.update_cnn_patterns(candle_data)
# Track consecutive errors for circuit breaker
consecutive_errors = 0
max_consecutive_errors = 5
# Episode loop
for step in range(max_steps_per_episode):
try:
# Select action using CNN-enhanced policy
action = agent.select_action(state, training=True, candle_data=candle_data)
# Take action
next_state, reward, done, info = env.step(action)
# Store transition in replay memory
agent.memory.push(state, action, reward, next_state, done)
# Move to the next state
state = next_state
# Update episode reward
episode_reward += reward
# Learn from experience
if len(agent.memory) > BATCH_SIZE:
try:
loss = agent.learn()
if loss is not None:
episode_losses.append(loss)
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logging.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update target network periodically
if step % TARGET_UPDATE == 0:
agent.update_target_network()
# Clean memory periodically during long episodes
if step % 200 == 0 and step > 0:
clean_memory()
# End episode if done
if done:
break
except Exception as e:
logging.error(f"Error in training step: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Calculate statistics
mean_loss = np.mean(episode_losses) if episode_losses else 0
balance = env.balance
pnl = balance - env.initial_balance
fees = env.total_fees
net_pnl = pnl - fees
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
trade_count = env.trade_count if hasattr(env, 'trade_count') else 0
# Update epsilon for exploration
epsilon = agent.update_epsilon(episode)
# Update statistics
stats['episode_rewards'].append(episode_reward)
stats['episode_lengths'].append(step + 1)
stats['balances'].append(balance)
stats['win_rates'].append(win_rate)
stats['episode_pnls'].append(pnl)
stats['drawdowns'].append(env.max_drawdown)
stats['trade_counts'].append(trade_count)
stats['loss_values'].append(mean_loss)
stats['fees'].append(fees)
stats['net_pnl_after_fees'].append(net_pnl)
            # Calculate and update cumulative PnL (the stats dict already
            # initializes the 'cumulative_pnl' list)
            cumulative_pnl = sum(stats['episode_pnls'])
            stats['cumulative_pnl'].append(cumulative_pnl)
            if writer:
                writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
                writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)
# Save model if this is the best reward or PnL
if episode_reward > best_reward:
best_reward = episode_reward
model_path = f"models/backtest/{period_name}_best_reward.pt" if period_name else "models/backtest/best_reward.pt"
try:
agent.save(model_path)
logging.info(f"New best reward: {best_reward:.2f}")
except Exception as e:
logging.error(f"Error saving best reward model: {e}")
logging.info(f"New best reward: {best_reward:.2f} (model not saved)")
if pnl > best_pnl:
best_pnl = pnl
model_path = f"models/backtest/{period_name}_best_pnl.pt" if period_name else "models/backtest/best_pnl.pt"
try:
agent.save(model_path)
logging.info(f"New best PnL: ${best_pnl:.2f}")
except Exception as e:
logging.error(f"Error saving best PnL model: {e}")
logging.info(f"New best PnL: ${best_pnl:.2f} (model not saved)")
# Save model if this is the best net PnL (after fees)
if net_pnl > best_net_pnl:
best_net_pnl = net_pnl
model_path = f"models/backtest/{period_name}_best_net_pnl.pt" if period_name else "models/backtest/best_net_pnl.pt"
try:
agent.save(model_path)
logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
except Exception as e:
logging.error(f"Error saving best net PnL model: {e}")
logging.info(f"New best Net PnL: ${best_net_pnl:.2f} (model not saved)")
            # Save checkpoint periodically (epsilon was already updated above)
            if episode % 10 == 0:
                try:
                    checkpoint_path = f'models/trading_agent_checkpoint_{episode}.pt'
                    if use_compact_save:
                        # compact_save expects the network, optimizer and size
                        # metadata rather than the whole agent
                        compact_save(agent.policy_net, agent.optimizer, episode_reward,
                                     agent.epsilon, agent.state_size, agent.action_size,
                                     agent.hidden_size, checkpoint_path)
                    else:
                        agent.save(checkpoint_path)
                except Exception as e:
                    logging.error(f"Error saving checkpoint model: {e}")
# Log training progress
logging.info(f"Episode {episode+1}/{num_episodes} | " +
f"Reward: {episode_reward:.2f} | " +
f"Balance: ${balance:.2f} | " +
f"PnL: ${pnl:.2f} | " +
f"Fees: ${fees:.2f} | " +
f"Net PnL: ${net_pnl:.2f} | " +
f"Win Rate: {win_rate:.2f} | " +
f"Trades: {trade_count} | " +
f"Loss: {mean_loss:.5f} | " +
f"Epsilon: {agent.epsilon:.4f}")
except Exception as e:
logging.error(f"Error in episode {episode}: {e}")
logging.error(traceback.format_exc())
continue
# Clean memory before saving final model
clean_memory()
# Save final model
if period_name:
try:
agent.save(f"models/backtest/{period_name}_final.pt")
logging.info(f"Saved final model for period: {period_name}")
except Exception as e:
logging.error(f"Error saving final model: {e}")
# Save backtesting statistics
stats_file = f"backtest_stats_{period_name}.csv" if period_name else "backtest_stats.csv"
try:
with open(stats_file, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['Episode', 'Reward', 'Balance', 'PnL', 'Fees', 'Net PnL', 'Win Rate', 'Trades', 'Loss'])
for i in range(len(stats['episode_rewards'])):
writer.writerow([
i+1,
stats['episode_rewards'][i],
stats['balances'][i],
stats['episode_pnls'][i],
stats['fees'][i],
stats['net_pnl_after_fees'][i],
stats['win_rates'][i],
stats['trade_counts'][i],
stats['loss_values'][i]
])
logging.info(f"Backtesting statistics saved to {stats_file}")
except Exception as e:
logging.error(f"Error saving backtesting statistics: {e}")
# Close exchange connection
if exchange:
try:
await exchange.close()
logging.info("Exchange connection closed successfully")
except AttributeError:
# Some exchanges don't have a close method
logging.info("Exchange doesn't have a close method, skipping")
except Exception as e:
logging.error(f"Error closing exchange connection: {e}")
return stats
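# Illustrative driver that trains over two consecutive 24h historical windows;
# agent and env are assumed to be constructed as in main().
async def _example_run_backtests(agent, env):
    """Run train_with_backtesting over a short sliding window (illustrative only)."""
    now_ms = int(time.time() * 1000)
    day_ms = 24 * 60 * 60 * 1000
    for days_back in (2, 1):
        stats = await train_with_backtesting(
            agent, env, symbol="ETH/USDT",
            since_timestamp=now_ms - days_back * day_ms,
            until_timestamp=now_ms - (days_back - 1) * day_ms,
            num_episodes=5, period_name=f"day_minus_{days_back}")
        if stats:
            logging.info(f"{stats['period']} cumulative PnL: "
                         f"${sum(stats['episode_pnls']):.2f}")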
# Robust save function with multiple fallbacks for PyTorch serialization errors and low disk space
def robust_save(model, path):
"""
Save a model with multiple fallback approaches to ensure file is saved
even in low disk space conditions.
"""
logger.info(f"Saving model to {path}.backup (attempt 1)")
backup_path = f"{path}.backup"
# Attempt 1: Regular save to backup file
try:
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If successful, copy to final path
try:
shutil.copy2(backup_path, path)
logger.info(f"Copied backup to {path}")
logger.info(f"Model saved successfully to {path}")
return True
except Exception as e:
logger.warning(f"Failed to copy backup to main file: {str(e)}")
logger.info(f"Using backup file as the main save")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {str(e)}")
# Attempt 2: Try with older pickle protocol
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
try:
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with protocol 2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {str(e)}")
# Attempt 3: Try without optimizer
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
try:
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logger.info(f"Successfully saved to {path} without optimizer")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {str(e)}")
# Attempt 4: Save model structure (as JSON) and parameters separately
logger.info(f"Saving model to {path} (attempt 4 - model structure as JSON)")
try:
# Save only essential model parameters as JSON
model_params = {
'epsilon': float(model.epsilon),
'state_size': model.state_size,
'action_size': model.action_size,
'hidden_size': model.hidden_size,
'lstm_layers': model.policy_net.lstm_layers if hasattr(model.policy_net, 'lstm_layers') else 2,
'attention_heads': model.policy_net.attention_heads if hasattr(model.policy_net, 'attention_heads') else 4
}
params_path = f"{path}.params.json"
with open(params_path, 'w') as f:
json.dump(model_params, f)
logger.info(f"Successfully saved model parameters to {params_path}")
# Now try to save a smaller version of the model without CNN components
# This is a more minimal save for recovery purposes
try:
# Create stripped down checkpoint with minimal components
minimal_checkpoint = {
'epsilon': model.epsilon,
'state_size': model.state_size,
'action_size': model.action_size,
'hidden_size': model.hidden_size
}
minimal_path = f"{path}.minimal"
torch.save(minimal_checkpoint, minimal_path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logger.info(f"Successfully saved minimal checkpoint to {minimal_path}")
except Exception as e:
logger.warning(f"Minimal checkpoint save failed: {str(e)}")
logger.info(f"Model saved successfully to {path}")
return True
except Exception as e:
logger.error(f"All save attempts failed for {path}: {str(e)}")
return False
def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False):
"""
Delete old model files to free up disk space.
Args:
keep_best (bool): Whether to keep the best model files (reward, pnl, net_pnl)
keep_latest_n (int): Number of latest checkpoint files to keep
aggressive (bool): If True, apply more aggressive cleanup in very low disk scenarios
"""
try:
logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}, aggressive={aggressive}")
models_dir = "models"
# Get all files in the models directory
all_files = os.listdir(models_dir)
# Files to potentially delete
checkpoint_files = []
backup_files = []
params_files = []
dated_files = []
# Best files to keep if keep_best is True
best_patterns = [
"trading_agent_best_reward.pt",
"trading_agent_best_pnl.pt",
"trading_agent_best_net_pnl.pt",
"trading_agent_final.pt"
]
# Categorize files for potential deletion
for filename in all_files:
file_path = os.path.join(models_dir, filename)
# Skip directories
if os.path.isdir(file_path):
continue
# Skip current best files if keep_best is True
if keep_best and any(filename == pattern for pattern in best_patterns):
continue
# Check for different file types
if "checkpoint" in filename and filename.endswith(".pt"):
checkpoint_files.append((filename, os.path.getmtime(file_path), file_path))
elif filename.endswith(".backup"):
backup_files.append((filename, os.path.getmtime(file_path), file_path))
elif filename.endswith(".params.json"):
params_files.append((filename, os.path.getmtime(file_path), file_path))
elif "_2025" in filename or "_2024" in filename: # Files with date stamps
dated_files.append((filename, os.path.getmtime(file_path), file_path))
bytes_freed = 0
files_deleted = 0
# Process checkpoint files - keep the newest N
if len(checkpoint_files) > keep_latest_n:
# Sort by modification time (newest first)
checkpoint_files.sort(key=lambda x: x[1], reverse=True)
# Keep the newest N files
files_to_delete = checkpoint_files[keep_latest_n:]
# Delete old checkpoint files
for _, _, file_path in files_to_delete:
try:
file_size = os.path.getsize(file_path)
os.remove(file_path)
bytes_freed += file_size
files_deleted += 1
logging.info(f"Deleted old checkpoint file: {file_path}")
except Exception as e:
logging.error(f"Failed to delete file {file_path}: {str(e)}")
# If aggressive cleanup is enabled, remove more files
if aggressive:
# Delete all backup files except the newest one
if backup_files:
backup_files.sort(key=lambda x: x[1], reverse=True)
for _, _, file_path in backup_files[1:]: # Keep only newest backup
try:
file_size = os.path.getsize(file_path)
os.remove(file_path)
bytes_freed += file_size
files_deleted += 1
logging.info(f"Deleted old backup file: {file_path}")
except Exception as e:
logging.error(f"Failed to delete file {file_path}: {str(e)}")
# Delete all dated files (these are typically archived models)
for _, _, file_path in dated_files:
try:
file_size = os.path.getsize(file_path)
os.remove(file_path)
bytes_freed += file_size
files_deleted += 1
logging.info(f"Deleted dated model file: {file_path}")
except Exception as e:
logging.error(f"Failed to delete file {file_path}: {str(e)}")
logging.info(f"Cleanup complete. Deleted {files_deleted} files, freed {bytes_freed / (1024*1024):.2f} MB")
# Check available disk space after cleanup
try:
if platform.system() == 'Windows':
free_bytes = ctypes.c_ulonglong(0)
ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(os.path.abspath(models_dir)), None, None, ctypes.pointer(free_bytes))
free_mb = free_bytes.value / (1024 * 1024)
else:
st = os.statvfs(os.path.abspath(models_dir))
free_mb = (st.f_bavail * st.f_frsize) / (1024 * 1024)
logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB")
# If space is still low, recommend aggressive cleanup
if free_mb < 200 and not aggressive: # Less than 200MB available
logging.warning("Disk space still critically low. Consider using aggressive cleanup.")
except Exception as e:
logging.error(f"Error checking disk space: {str(e)}")
except Exception as e:
logging.error(f"Error during file cleanup: {str(e)}")
logging.error(traceback.format_exc())
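# Illustrative recovery pairing of the two helpers above: if a save fails
# (typically due to low disk space), free space and retry once. The agent is
# assumed to expose the attributes robust_save reads (policy_net, etc.).
def _example_save_with_cleanup(agent, path='models/trading_agent_best_pnl.pt'):
    """Retry a failed model save after aggressive cleanup (illustrative only)."""
    if not robust_save(agent, path):
        cleanup_model_files(keep_best=True, keep_latest_n=2, aggressive=True)
        robust_save(agent, path)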
def compact_save(model, optimizer, reward, epsilon, state_size, action_size, hidden_size, path, use_quantization=False):
"""
Save a model in a compact format suitable for low disk space environments.
Includes fallbacks if the primary save method fails.
Args:
model: The model to save
optimizer: The optimizer to save
reward: The current reward
epsilon: The current epsilon value
state_size: The state size
action_size: The action size
hidden_size: The hidden size
path: The path to save to
use_quantization: Whether to use quantization to reduce model size
Returns:
bool: Whether the save was successful
"""
try:
# Create minimal checkpoint with essential data only
checkpoint = {
'model_state_dict': model.state_dict(),
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size
}
# Apply quantization if requested
if use_quantization:
try:
logging.info(f"Attempting quantized save to {path}")
# Quantize model to int8
quantized_model = torch.quantization.quantize_dynamic(
model, # the original model
{torch.nn.Linear}, # a set of layers to dynamically quantize
dtype=torch.qint8 # the target dtype for quantized weights
)
# Create quantized checkpoint
quantized_checkpoint = {
'model_state_dict': quantized_model.state_dict(),
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size,
'is_quantized': True
}
# Save with older pickle protocol and disable new zipfile serialization
torch.save(quantized_checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logging.info(f"Quantized compact save successful to {path}")
return True
except Exception as e:
logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}")
# Fall back to regular save if quantization fails
# Regular save with older pickle protocol and no zipfile serialization
torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
logging.info(f"Compact save successful to {path}")
return True
except Exception as e:
logging.error(f"Compact save failed: {str(e)}")
logging.error(traceback.format_exc())
# Fallback: Save just the parameters as JSON if we can't save the full model
try:
params = {
'epsilon': epsilon,
'state_size': state_size,
'action_size': action_size,
'hidden_size': hidden_size
}
json_path = f"{path}.params.json"
with open(json_path, 'w') as f:
json.dump(params, f)
logging.info(f"Saved minimal parameters to {json_path}")
return False
except Exception as json_e:
logging.error(f"JSON parameter save failed: {str(json_e)}")
return False
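# Illustrative compact_save call matching its signature; the agent attributes
# used here (policy_net, optimizer, epsilon, state_size, action_size,
# hidden_size) are the same ones robust_save reads above.
def _example_compact_save(agent, reward=0.0):
    """Save a minimal quantized checkpoint (illustrative only)."""
    return compact_save(agent.policy_net, agent.optimizer, reward,
                        agent.epsilon, agent.state_size, agent.action_size,
                        agent.hidden_size, 'models/compact_checkpoint.pt',
                        use_quantization=True)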
if __name__ == "__main__":
# Parse command line arguments
parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, default='train', help='Mode: train, eval, live')
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode')
parser.add_argument('--update_interval', type=int, default=10, help='Target network update interval')
parser.add_argument('--training_iterations', type=int, default=10, help='Number of training iterations per step')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for candlestick data')
parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage')
parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes')
parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training')
parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space')
parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up')
    # main() re-parses sys.argv with its own parser, so tolerate flags that
    # only main() understands; platform and ctypes are already imported at
    # the top of the module
    args, _ = parser.parse_known_args()
# Run cleanup if requested
if args.cleanup:
cleanup_model_files(keep_best=True, keep_latest_n=args.keep_latest, aggressive=args.aggressive_cleanup)
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")