5166 lines
215 KiB
Python
5166 lines
215 KiB
Python
import os
|
|
import time
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
import random
|
|
import logging
|
|
import asyncio
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
from collections import deque, namedtuple
|
|
from dotenv import load_dotenv
|
|
import ccxt
|
|
import websockets
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import torch.cuda.amp as amp # Add this import at the top
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
import copy
|
|
import argparse
|
|
import traceback
|
|
import io
|
|
import matplotlib.dates as mdates
|
|
from matplotlib.figure import Figure
|
|
from PIL import Image
|
|
import matplotlib.pyplot as mpf
|
|
import matplotlib.gridspec as gridspec
|
|
import datetime
|
|
from datetime import datetime as dt
|
|
from collections import defaultdict
|
|
from gym.spaces import Discrete, Box
|
|
import csv
|
|
import gc
|
|
import shutil
|
|
import math
|
|
import platform
|
|
import ctypes
|
|
|
|
# Configure logging: INFO and above are written both to "trading_bot.log" and to
# the console. This runs at import time; logging.basicConfig is a no-op if the
# root logger already has handlers configured.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)

# Named logger used by the classes and functions in this module.
logger = logging.getLogger("trading_bot")
|
|
|
|
# Load environment variables (e.g. from a local .env file) before reading credentials.
load_dotenv()

# MEXC API credentials; os.getenv returns None when a variable is not set.
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
|
|
|
|
# Constants
INITIAL_BALANCE = 100 # USD (note: TradingEnvironment defaults to initial_balance=10000, not this value)
MAX_LEVERAGE = 100  # TradingEnvironment.step sets leverage to this, or half of it when is_volatile_market()
# NOTE(review): TradingEnvironment.step hard-codes a 2% stop loss / 4% take profit
# instead of the two values below -- confirm which is intended.
STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5%
MEMORY_SIZE = 100000  # presumably the ReplayMemory capacity -- confirm at call site
BATCH_SIZE = 64  # presumably the replay minibatch size
GAMMA = 0.99 # Discount factor
# Presumably the epsilon-greedy exploration schedule; the unit of EPSILON_DECAY
# (steps vs. episodes) is not visible here -- TODO confirm.
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 64 # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10 # Update target network every 10 episodes

# Experience replay tuple: one transition as stored by ReplayMemory.push.
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
|
|
|
|
def _manual_local_extrema(prices, window):
    """Return ``(tops, bottoms)`` found by brute-force neighbour comparison.

    A point is a top (bottom) when it is >= (<=) every one of the ``window``
    samples on each side, so plateaus can yield several adjacent indices.
    """
    tops = []
    bottoms = []
    for i in range(window, len(prices) - window):
        left = prices[i - window:i]
        right = prices[i + 1:i + window + 1]
        if np.all(prices[i] >= left) and np.all(prices[i] >= right):
            tops.append(i)
        if np.all(prices[i] <= left) and np.all(prices[i] <= right):
            bottoms.append(i)
    return tops, bottoms


def _filter_significant(prices, indices, window, threshold, is_top):
    """Keep extrema that stand out by more than ``threshold`` within +/- ``window`` samples.

    Only indices with a full window on both sides are considered.
    """
    kept = []
    n = len(prices)
    for idx in indices:
        if idx < window or idx >= n - window:
            continue  # not enough context on one side to judge significance
        neighbourhood = prices[idx - window:idx + window + 1]  # symmetric, includes idx
        if is_top:
            prominence = prices[idx] - neighbourhood.min()
        else:
            prominence = neighbourhood.max() - prices[idx]
        if prominence > threshold:
            kept.append(int(idx))
    return kept


def find_local_extrema(prices, window=5):
    """
    Find significant local extrema (tops and bottoms) in a price series.

    Candidates come from ``scipy.signal.find_peaks`` (with ``distance=window``)
    when SciPy is installed, otherwise from a brute-force neighbour comparison.
    Either way, candidates are then kept only if they differ from the extreme
    of their +/- ``window`` neighbourhood by more than 1.5x the mean absolute
    price change of the whole series.

    Args:
        prices: 1-D sequence of prices (list, tuple or numpy array).
        window: Neighbourhood half-width in samples (>= 1 when SciPy is used).

    Returns:
        Tuple ``(tops, bottoms)`` of lists of plain ``int`` indices. Both lists
        are empty when the series has fewer than ``2 * window + 1`` points.
    """
    # Accept plain lists: the scipy path negates the series, which fails on lists.
    prices = np.asarray(prices, dtype=float)
    if len(prices) < window * 2 + 1:
        return [], []

    try:
        from scipy.signal import find_peaks
    except ImportError:
        tops, bottoms = _manual_local_extrema(prices, window)
    else:
        tops = [int(i) for i in find_peaks(prices, distance=window)[0]]
        # Valleys are the peaks of the inverted series.
        bottoms = [int(i) for i in find_peaks(-prices, distance=window)[0]]

    # Significance threshold relative to typical bar-to-bar movement. Tops and
    # bottoms are filtered independently, on both detection paths.
    avg_move = np.mean(np.abs(np.diff(prices)))
    threshold = avg_move * 1.5
    tops = _filter_significant(prices, tops, window, threshold, is_top=True)
    bottoms = _filter_significant(prices, bottoms, window, threshold, is_top=False)
    return tops, bottoms
|
|
|
|
class ReplayMemory:
    """Bounded FIFO store of ``Experience`` transitions for experience replay.

    Once ``capacity`` transitions are held, each new one evicts the oldest.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) performs the eviction automatically on append.
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record a single transition."""
        transition = Experience(state=state, action=action, reward=reward,
                                next_state=next_state, done=done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises ValueError (from ``random.sample``) if fewer transitions are stored.
        """
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return self.memory.__len__()
|
|
|
|
class DQN(nn.Module):
    """Deep Q-Network: a thin ``nn.Module`` wrapper around ``LSTMAttentionDQN``.

    The architecture hyper-parameters are mirrored as attributes on the wrapper.
    """

    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super().__init__()
        self.network = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads)
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.attention_heads = attention_heads

    def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None):
        """Compute Q-values for ``state``.

        The per-timeframe inputs are forwarded only when ``x_1m``, ``x_1h`` and
        ``x_1d`` are all provided (``x_1s`` is passed through even if ``None``);
        otherwise the wrapped network receives ``state`` alone.
        """
        timeframes = (x_1m, x_1h, x_1d)
        if any(frame is None for frame in timeframes):
            return self.network(state)
        return self.network(state, x_1s, x_1m, x_1h, x_1d)
|
|
|
|
class PricePredictionModel(nn.Module):
    """LSTM regressor forecasting the next ``output_size`` prices from a window of past prices.

    Each time step carries a single feature (the min-max scaled price). The scaler is
    fitted once, on the first series passed to :meth:`preprocess`, and reused for all
    later calls, so later prices outside that first range scale outside ``[0, 1]``.
    """

    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2, seq_len=30):
        """Build the network.

        Args:
            input_size: Kept for backward compatibility; not used by the layers
                (the LSTM always consumes one feature per time step).
            hidden_size: LSTM hidden units.
            output_size: Number of future prices predicted per window (forecast horizon).
            num_layers: Number of stacked LSTM layers (dropout 0.2 between them).
            seq_len: Number of past prices fed to the LSTM per prediction.
        """
        super(PricePredictionModel, self).__init__()
        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False
        self.seq_len = seq_len
        self.output_size = output_size

    def forward(self, x):
        """Map ``x`` of shape ``[batch, seq_len, 1]`` to ``[batch, output_size]`` predictions."""
        lstm_out, _ = self.lstm(x)
        # Predict from the hidden state of the last time step only.
        return self.fc(lstm_out[:, -1, :])

    def preprocess(self, data):
        """Scale ``data`` to a column vector, fitting the scaler on the first call only."""
        data_reshaped = np.array(data).reshape(-1, 1)
        if not self.is_fitted:
            self.scaler.fit(data_reshaped)
            self.is_fitted = True
        return self.scaler.transform(data_reshaped)

    def postprocess(self, scaled_predictions):
        """Inverse-transform scaled predictions back to price units (1-D array)."""
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        """Predict the next prices from the last ``seq_len`` entries of ``price_history``.

        Returns ``np.zeros(num_candles)`` when the history is shorter than ``seq_len``;
        otherwise returns ``output_size`` predicted prices.
        """
        seq_len = self.seq_len
        if len(price_history) < seq_len:
            return np.zeros(num_candles)

        scaled_data = self.preprocess(price_history)
        sequence = scaled_data[-seq_len:].reshape(1, seq_len, 1)
        device = next(self.parameters()).device
        sequence_tensor = torch.FloatTensor(sequence).to(device)

        # torch.no_grad() does not disable dropout; switch to eval mode so the
        # prediction is deterministic, then restore whatever mode we were in.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
        finally:
            self.train(was_training)

        return self.postprocess(scaled_predictions)

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        """Fit on every full (``seq_len`` inputs, ``output_size`` targets) window of the history.

        Args:
            price_history: Sequence of prices.
            optimizer: Optimizer over this model's parameters.
            epochs: Number of full-batch passes.

        Returns:
            Mean MSE loss (in scaled units) over the epochs, or ``0.0`` when the
            history is shorter than one window or ``epochs`` is not positive.
        """
        seq_len = self.seq_len
        horizon = self.output_size
        window = seq_len + horizon
        if len(price_history) < window or epochs <= 0:
            return 0.0

        scaled_data = self.preprocess(price_history)

        # +1 so the final window (ending exactly at the last sample) is included.
        num_windows = len(scaled_data) - window + 1
        sequences = np.stack([scaled_data[i:i + seq_len] for i in range(num_windows)])
        targets = np.stack([scaled_data[i + seq_len:i + window].flatten() for i in range(num_windows)])

        device = next(self.parameters()).device
        sequences_tensor = torch.FloatTensor(sequences.reshape(-1, seq_len, 1)).to(device)
        targets_tensor = torch.FloatTensor(targets).to(device)

        self.train()
        total_loss = 0.0
        for _ in range(epochs):
            predictions = self(sequences_tensor)
            loss = F.mse_loss(predictions, targets_tensor)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / epochs
|
|
|
|
class TradingEnvironment:
|
|
def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
|
|
window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
|
|
symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
|
|
"""Initialize the trading environment"""
|
|
self.api_key = api_key
|
|
self.api_secret = api_secret
|
|
self.exchange_id = exchange_id
|
|
self.symbol = symbol
|
|
self.timeframe = timeframe
|
|
self.init_length = init_length
|
|
|
|
# TODO: For 1s/ticks timeframes, implement WebSocket API integration for real-time data
|
|
|
|
try:
|
|
# Initialize exchange if API credentials are provided
|
|
if api_key and api_secret:
|
|
self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
|
|
logger.info(f"Exchange initialized: {exchange_id}")
|
|
# Fetch historical data
|
|
self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
|
|
if not self.data:
|
|
raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
|
|
self.data_format_is_list = isinstance(self.data[0], list)
|
|
logger.info(f"Loaded {len(self.data)} candles from exchange")
|
|
elif data is not None: # Use provided data
|
|
self.data = data
|
|
self.data_format_is_list = isinstance(self.data[0], list)
|
|
logger.info(f"Using provided data with {len(self.data)} candles")
|
|
else:
|
|
# Initialize with empty data, we'll load it later with fetch_initial_data
|
|
logger.warning("No data provided, initializing with empty data")
|
|
self.data = []
|
|
self.data_format_is_list = True
|
|
except Exception as e:
|
|
logger.error(f"Error initializing environment: {e}")
|
|
raise
|
|
|
|
# Initialize features and feature extractors
|
|
if features is not None:
|
|
self.features = features
|
|
# Create a dictionary of features
|
|
self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
|
|
else:
|
|
# Initialize features as a dictionary, not a list
|
|
self.features = {
|
|
'price': [],
|
|
'volume': [],
|
|
'rsi': [],
|
|
'macd': [],
|
|
'macd_signal': [],
|
|
'macd_hist': [],
|
|
'bollinger_upper': [],
|
|
'bollinger_mid': [],
|
|
'bollinger_lower': [],
|
|
'stoch_k': [],
|
|
'stoch_d': [],
|
|
'ema_9': [],
|
|
'ema_21': [],
|
|
'atr': []
|
|
}
|
|
self.features_dict = {}
|
|
|
|
if feature_extractors is None:
|
|
feature_extractors = []
|
|
self.feature_extractors = feature_extractors
|
|
|
|
# Environment parameters
|
|
self.initial_balance = initial_balance
|
|
self.balance = initial_balance
|
|
self.leverage = leverage
|
|
self.position = 'flat' # 'flat', 'long', or 'short'
|
|
self.position_size = 0
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
self.commission = commission
|
|
self.total_pnl = 0
|
|
self.total_fees = 0.0 # Track total fees paid
|
|
self.trades = []
|
|
self.trade_signals = []
|
|
self.current_step = 0
|
|
self.window_size = window_size
|
|
self.max_steps = max_steps
|
|
self.peak_balance = initial_balance
|
|
self.max_drawdown = 0
|
|
self.current_price = 0
|
|
self.win_count = 0
|
|
self.loss_count = 0
|
|
self.min_position_size = 100 # Minimum position size in USD
|
|
|
|
# Track candle patterns and reversal points
|
|
self.patterns = {}
|
|
self.reversal_points = []
|
|
|
|
# Define observation and action spaces
|
|
num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
|
|
state_dim = window_size * 5 + 5 + num_features # OHLCV + position info + features
|
|
|
|
self.action_space = Discrete(4) # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
|
|
self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)
|
|
|
|
# Check if we have enough data
|
|
if len(self.data) < self.window_size:
|
|
logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
|
|
|
|
def calculate_reward(self, action):
|
|
"""Calculate reward based on the action taken"""
|
|
reward = 0
|
|
|
|
# Base reward structure
|
|
if self.position == 'flat':
|
|
if action == 0: # HOLD when flat
|
|
reward = 0.01 # Small reward for holding when no position
|
|
elif action == 1: # BUY/LONG
|
|
# Check for buy signal in CNN patterns
|
|
if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
|
|
buy_confidence = self.cnn_patterns['long_confidence']
|
|
# Scale by confidence
|
|
reward = 0.1 * buy_confidence * 10
|
|
else:
|
|
reward = 0.1 # Default reward for taking a position
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
elif action == 2: # SELL/SHORT
|
|
# Check for sell signal in CNN patterns
|
|
if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
|
|
sell_confidence = self.cnn_patterns['short_confidence']
|
|
# Scale by confidence
|
|
reward = 0.1 * sell_confidence * 10
|
|
else:
|
|
reward = 0.1 # Default reward for taking a position
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
elif action == 3: # CLOSE when no position
|
|
reward = -0.1 # Penalty for trying to close no position
|
|
|
|
elif self.position == 'long':
|
|
if action == 0: # HOLD long position
|
|
# Calculate price change since entry
|
|
price_change = (self.current_price - self.entry_price) / self.entry_price
|
|
|
|
# Reward or penalize based on price movement
|
|
if price_change > 0:
|
|
reward = price_change * 10 # Reward for holding profitable position
|
|
else:
|
|
reward = price_change * 5 # Smaller penalty for holding losing position
|
|
|
|
elif action == 1: # BUY when already long
|
|
reward = -0.1 # Penalty for redundant action
|
|
|
|
elif action == 2: # SELL when long (reversal)
|
|
# Calculate PnL
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
|
|
|
if pnl_percent > 0:
|
|
reward = -0.5 # Penalty for closing profitable long position to go short
|
|
else:
|
|
# Check for sell signal in CNN patterns
|
|
if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
|
|
sell_confidence = self.cnn_patterns['short_confidence']
|
|
reward = 0.2 * sell_confidence * 10 # Reward for correct reversal
|
|
else:
|
|
reward = 0.2 # Default reward for cutting loss
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
|
|
elif action == 3: # CLOSE long position
|
|
# Calculate PnL
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
|
|
|
if pnl_percent > 0:
|
|
reward = pnl_percent * 15 # Higher reward for taking profit
|
|
else:
|
|
reward = pnl_percent * 5 # Smaller penalty for cutting loss
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
|
|
elif self.position == 'short':
|
|
if action == 0: # HOLD short position
|
|
# Calculate price change since entry
|
|
price_change = (self.entry_price - self.current_price) / self.entry_price
|
|
|
|
# Reward or penalize based on price movement
|
|
if price_change > 0:
|
|
reward = price_change * 10 # Reward for holding profitable position
|
|
else:
|
|
reward = price_change * 5 # Smaller penalty for holding losing position
|
|
|
|
elif action == 1: # BUY when short (reversal)
|
|
# Calculate PnL
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
|
|
|
if pnl_percent > 0:
|
|
reward = -0.5 # Penalty for closing profitable short position to go long
|
|
else:
|
|
# Check for buy signal in CNN patterns
|
|
if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
|
|
buy_confidence = self.cnn_patterns['long_confidence']
|
|
reward = 0.2 * buy_confidence * 10 # Reward for correct reversal
|
|
else:
|
|
reward = 0.2 # Default reward for cutting loss
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
|
|
elif action == 2: # SELL when already short
|
|
reward = -0.1 # Penalty for redundant action
|
|
|
|
elif action == 3: # CLOSE short position
|
|
# Calculate PnL
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
|
|
|
if pnl_percent > 0:
|
|
reward = pnl_percent * 15 # Higher reward for taking profit
|
|
else:
|
|
reward = pnl_percent * 5 # Smaller penalty for cutting loss
|
|
|
|
# Apply fee penalty
|
|
if self.position_size > 0:
|
|
fee = (self.position_size / 1900) * 1
|
|
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
|
|
reward -= fee_penalty
|
|
|
|
return reward
|
|
|
|
def reset(self):
|
|
"""Reset the environment to its initial state and return the initial observation"""
|
|
self.balance = self.initial_balance
|
|
self.position = 'flat'
|
|
self.position_size = 0
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
self.current_step = 0
|
|
self.trades = []
|
|
self.trade_signals = []
|
|
self.total_pnl = 0.0
|
|
self.total_fees = 0.0
|
|
self.peak_balance = self.initial_balance
|
|
self.max_drawdown = 0.0
|
|
self.win_count = 0
|
|
self.loss_count = 0
|
|
|
|
return self.get_state()
|
|
|
|
def add_data(self, candle):
|
|
"""Add a new candle to the data"""
|
|
# Check if candle is a list or dictionary
|
|
if isinstance(candle, list):
|
|
self.data_format_is_list = True
|
|
self.data.append(candle)
|
|
self.current_price = candle[4] # Close price is at index 4
|
|
else:
|
|
self.data_format_is_list = False
|
|
self.data.append(candle)
|
|
self.current_price = candle['close']
|
|
|
|
self._update_features()
|
|
|
|
def _initialize_features(self):
|
|
"""Initialize technical indicators and features"""
|
|
if len(self.data) < 30:
|
|
return
|
|
|
|
# Convert data to pandas DataFrame for easier calculation
|
|
if self.data_format_is_list:
|
|
# Convert list format to DataFrame
|
|
df = pd.DataFrame(self.data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
|
|
else:
|
|
# Dictionary format
|
|
df = pd.DataFrame(self.data)
|
|
|
|
# Basic price and volume
|
|
self.features['price'] = df['close'].values
|
|
self.features['volume'] = df['volume'].values
|
|
|
|
# Calculate RSI (14 periods)
|
|
delta = df['close'].diff()
|
|
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
|
|
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
|
|
rs = gain / loss
|
|
self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values
|
|
|
|
# Calculate MACD
|
|
ema12 = df['close'].ewm(span=12, adjust=False).mean()
|
|
ema26 = df['close'].ewm(span=26, adjust=False).mean()
|
|
macd = ema12 - ema26
|
|
signal = macd.ewm(span=9, adjust=False).mean()
|
|
self.features['macd'] = macd.values
|
|
self.features['macd_signal'] = signal.values
|
|
self.features['macd_hist'] = (macd - signal).values
|
|
|
|
# Calculate Bollinger Bands
|
|
sma20 = df['close'].rolling(window=20).mean()
|
|
std20 = df['close'].rolling(window=20).std()
|
|
self.features['bollinger_upper'] = (sma20 + 2 * std20).values
|
|
self.features['bollinger_mid'] = sma20.values
|
|
self.features['bollinger_lower'] = (sma20 - 2 * std20).values
|
|
|
|
# Calculate Stochastic Oscillator
|
|
low_14 = df['low'].rolling(window=14).min()
|
|
high_14 = df['high'].rolling(window=14).max()
|
|
k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
|
|
self.features['stoch_k'] = k.values
|
|
self.features['stoch_d'] = k.rolling(window=3).mean().values
|
|
|
|
# Calculate EMAs
|
|
self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
|
|
self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values
|
|
|
|
# Calculate ATR
|
|
high_low = df['high'] - df['low']
|
|
high_close = (df['high'] - df['close'].shift()).abs()
|
|
low_close = (df['low'] - df['close'].shift()).abs()
|
|
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
|
self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values
|
|
|
|
def _update_features(self):
|
|
"""Update technical indicators with new data"""
|
|
self._initialize_features() # Recalculate all features
|
|
|
|
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
|
|
"""Fetch initial historical data for the environment"""
|
|
try:
|
|
logger.info(f"Fetching initial data for {symbol}")
|
|
|
|
# Use the refactored fetch method
|
|
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
|
|
|
|
# Update environment with fetched data
|
|
if data:
|
|
self.data = data
|
|
self._initialize_features()
|
|
logger.info(f"Initialized environment with {len(data)} candles")
|
|
else:
|
|
logger.warning("No initial data received")
|
|
|
|
return len(data) > 0
|
|
except Exception as e:
|
|
logger.error(f"Error fetching initial data: {e}")
|
|
return False
|
|
|
|
def step(self, action):
|
|
"""Take an action in the environment and return the next state, reward, and done flag"""
|
|
# Check if we have enough data
|
|
if self.current_step >= len(self.data) - 1:
|
|
# We've reached the end of data
|
|
done = True
|
|
next_state = self.get_state()
|
|
info = {
|
|
'action': 'none',
|
|
'price': self.current_price,
|
|
'balance': self.balance,
|
|
'position': self.position,
|
|
'pnl': self.total_pnl
|
|
}
|
|
return next_state, 0, done, info
|
|
|
|
# Circuit breaker after consecutive losses
|
|
if self.count_consecutive_losses() >= 5:
|
|
logger.warning("Circuit breaker triggered after 5 consecutive losses")
|
|
return self.get_state(), -1, True, {'action': 'circuit_breaker_triggered'}
|
|
|
|
# Reduce leverage in volatile markets
|
|
if self.is_volatile_market():
|
|
self.leverage = MAX_LEVERAGE * 0.5 # Half leverage in volatile markets
|
|
else:
|
|
self.leverage = MAX_LEVERAGE
|
|
|
|
# Store current price before taking action
|
|
if self.data_format_is_list:
|
|
self.current_price = self.data[self.current_step][4] # Close price
|
|
else:
|
|
self.current_price = self.data[self.current_step]['close']
|
|
|
|
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
|
|
reward = self.calculate_reward(action)
|
|
|
|
# Execute the action
|
|
initial_balance = self.balance # Store initial balance to calculate PnL
|
|
|
|
# Open long position
|
|
if action == 1 and self.position != 'long':
|
|
if self.position == 'short':
|
|
# Close short position first
|
|
if self.position_size > 0:
|
|
# Calculate PnL
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
|
|
|
# Update balance and record trade
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Record trade
|
|
trade_duration = self.current_step - self.entry_index
|
|
if self.data_format_is_list:
|
|
timestamp = self.data[self.current_step][0] # Timestamp
|
|
else:
|
|
timestamp = self.data[self.current_step]['timestamp']
|
|
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'fee': fee,
|
|
'net_pnl': pnl_dollar - fee,
|
|
'duration': trade_duration,
|
|
'timestamp': timestamp,
|
|
'reason': 'action_change'
|
|
})
|
|
|
|
# Update win/loss count
|
|
if pnl_dollar > 0:
|
|
self.win_count += 1
|
|
else:
|
|
self.loss_count += 1
|
|
|
|
# Now open long position
|
|
self.position = 'long'
|
|
self.entry_price = self.current_price
|
|
self.entry_index = self.current_step
|
|
|
|
# Calculate position size with risk management
|
|
self.position_size = self.calculate_position_size()
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Set stop loss and take profit
|
|
sl_percent = 0.02 # 2% stop loss
|
|
tp_percent = 0.04 # 4% take profit
|
|
|
|
self.stop_loss = self.entry_price * (1 - sl_percent)
|
|
self.take_profit = self.entry_price * (1 + tp_percent)
|
|
|
|
# Open short position
|
|
elif action == 2 and self.position != 'short':
|
|
if self.position == 'long':
|
|
# Close long position first
|
|
if self.position_size > 0:
|
|
# Calculate PnL
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
|
|
|
# Update balance and record trade
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Record trade
|
|
trade_duration = self.current_step - self.entry_index
|
|
if self.data_format_is_list:
|
|
timestamp = self.data[self.current_step][0] # Timestamp
|
|
else:
|
|
timestamp = self.data[self.current_step]['timestamp']
|
|
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'fee': fee,
|
|
'net_pnl': pnl_dollar - fee,
|
|
'duration': trade_duration,
|
|
'timestamp': timestamp,
|
|
'reason': 'action_change'
|
|
})
|
|
|
|
# Update win/loss count
|
|
if pnl_dollar > 0:
|
|
self.win_count += 1
|
|
else:
|
|
self.loss_count += 1
|
|
|
|
# Now open short position
|
|
self.position = 'short'
|
|
self.entry_price = self.current_price
|
|
self.entry_index = self.current_step
|
|
|
|
# Calculate position size with risk management
|
|
self.position_size = self.calculate_position_size()
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Set stop loss and take profit
|
|
sl_percent = 0.02 # 2% stop loss
|
|
tp_percent = 0.04 # 4% take profit
|
|
|
|
self.stop_loss = self.entry_price * (1 + sl_percent)
|
|
self.take_profit = self.entry_price * (1 - tp_percent)
|
|
|
|
# Close position
|
|
elif action == 3 and self.position != 'flat':
|
|
if self.position == 'long':
|
|
# Calculate PnL
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
|
|
|
# Update balance and record trade
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Record trade
|
|
trade_duration = self.current_step - self.entry_index
|
|
if self.data_format_is_list:
|
|
timestamp = self.data[self.current_step][0] # Timestamp
|
|
else:
|
|
timestamp = self.data[self.current_step]['timestamp']
|
|
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'fee': fee,
|
|
'net_pnl': pnl_dollar - fee,
|
|
'duration': trade_duration,
|
|
'timestamp': timestamp,
|
|
'reason': 'close_action'
|
|
})
|
|
|
|
# Update win/loss count
|
|
if pnl_dollar > 0:
|
|
self.win_count += 1
|
|
else:
|
|
self.loss_count += 1
|
|
|
|
elif self.position == 'short':
|
|
# Calculate PnL
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
|
|
|
# Update balance and record trade
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Apply trading fee (1 USD per 1.9k position)
|
|
fee = (self.position_size / 1900) * 1
|
|
self.balance -= fee
|
|
self.total_fees += fee
|
|
|
|
# Record trade
|
|
trade_duration = self.current_step - self.entry_index
|
|
if self.data_format_is_list:
|
|
timestamp = self.data[self.current_step][0] # Timestamp
|
|
else:
|
|
timestamp = self.data[self.current_step]['timestamp']
|
|
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'fee': fee,
|
|
'net_pnl': pnl_dollar - fee,
|
|
'duration': trade_duration,
|
|
'timestamp': timestamp,
|
|
'reason': 'close_action'
|
|
})
|
|
|
|
# Update win/loss count
|
|
if pnl_dollar > 0:
|
|
self.win_count += 1
|
|
else:
|
|
self.loss_count += 1
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.position_size = 0
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
# Record trade signal for visualization
|
|
if action > 0: # If not HOLD
|
|
signal_type = None
|
|
if action == 1: # BUY/LONG
|
|
signal_type = 'buy'
|
|
elif action == 2: # SELL/SHORT
|
|
signal_type = 'sell'
|
|
elif action == 3: # CLOSE
|
|
if self.position == 'long':
|
|
signal_type = 'close_long'
|
|
elif self.position == 'short':
|
|
signal_type = 'close_short'
|
|
|
|
if signal_type:
|
|
if self.data_format_is_list:
|
|
timestamp = self.data[self.current_step][0] # Timestamp
|
|
else:
|
|
timestamp = self.data[self.current_step]['timestamp']
|
|
|
|
self.trade_signals.append({
|
|
'timestamp': timestamp,
|
|
'price': self.current_price,
|
|
'type': signal_type,
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Check for stop loss / take profit hits
|
|
self.check_sl_tp()
|
|
|
|
# Move to next step
|
|
self.current_step += 1
|
|
done = self.current_step >= len(self.data) - 1
|
|
|
|
# Get new state
|
|
next_state = self.get_state()
|
|
|
|
# Update peak balance and drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
|
|
current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
|
|
self.max_drawdown = max(self.max_drawdown, current_drawdown)
|
|
|
|
# Create info dictionary
|
|
info = {
|
|
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
|
'price': self.current_price,
|
|
'balance': self.balance,
|
|
'position': self.position,
|
|
'pnl': self.total_pnl,
|
|
'fees': self.total_fees,
|
|
'net_pnl': self.total_pnl - self.total_fees
|
|
}
|
|
|
|
return next_state, reward, done, info
|
|
|
|
def check_sl_tp(self):
    """Close the open position if its stop-loss or take-profit level was hit.

    While a long (short) is more than 1% in profit, the stop is trailed 0.5%
    below (above) the current price. The stop is only ever tightened. The
    stop-loss test runs before the take-profit test. If both levels are
    crossed on the same step, the position is closed at the stop.

    Does nothing when flat. When a level is hit the position is realised at
    that level via ``_close_at_level``.
    """
    if self.position == 'flat':
        return

    if self.position == 'long':
        # Trail the stop once the long is more than 1% in profit.
        if self.current_price > self.entry_price * 1.01:
            self.stop_loss = max(self.stop_loss, self.current_price * 0.995)

        if self.current_price <= self.stop_loss:
            self._close_at_level(self.stop_loss, 'stop_loss')
        elif self.current_price >= self.take_profit:
            self._close_at_level(self.take_profit, 'take_profit')

    elif self.position == 'short':
        # Trail the stop once the short is more than 1% in profit.
        if self.current_price < self.entry_price * 0.99:
            self.stop_loss = min(self.stop_loss, self.current_price * 1.005)

        if self.current_price >= self.stop_loss:
            self._close_at_level(self.stop_loss, 'stop_loss')
        elif self.current_price <= self.take_profit:
            self._close_at_level(self.take_profit, 'take_profit')


def _close_at_level(self, exit_price, reason):
    """Realise the open position at ``exit_price`` and reset to a flat state.

    Args:
        exit_price: price the position is closed at (the SL or TP level).
        reason: ``'stop_loss'`` or ``'take_profit'``. It is stored on the
            trade record and selects the log label.

    Side effects: updates ``balance``, ``total_pnl`` and
    ``win_count``/``loss_count``, appends a record to ``trades``, logs the
    exit and clears the position fields.
    """
    side = self.position
    if side == 'long':
        pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
    else:
        pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100

    # The percentage move is applied linearly to position_size (presumably a
    # notional amount), then fees on the same size are deducted.
    pnl_dollar = pnl_percent / 100 * self.position_size
    pnl_dollar -= self.calculate_fees(self.position_size)

    self.balance += pnl_dollar
    self.total_pnl += pnl_dollar

    # Support both candle layouts used in this class: list rows (timestamp
    # at index 0) and dict rows (timestamp under 'timestamp').
    row = self.data[self.current_step]
    timestamp = row[0] if self.data_format_is_list else row['timestamp']

    self.trades.append({
        'type': side,
        'entry': self.entry_price,
        'exit': exit_price,
        'pnl_percent': pnl_percent,
        'pnl_dollar': pnl_dollar,
        'duration': self.current_step - self.entry_index,
        'timestamp': timestamp,
        'reason': reason,
    })

    # Classify by realised (post-fee) PnL for every exit reason. A take-profit
    # that fees turned negative is therefore not counted as a win.
    if pnl_dollar > 0:
        self.win_count += 1
    else:
        self.loss_count += 1

    label = 'STOP LOSS' if reason == 'stop_loss' else 'TAKE PROFIT'
    logger.info(f"{label} hit for {side} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

    # Reset position
    self.position = 'flat'
    self.position_size = 0
    self.entry_price = 0
    self.entry_index = 0
    self.stop_loss = 0
    self.take_profit = 0
|
|
|
|
def get_state(self):
    """Build the observation vector for the agent (length ``STATE_SIZE``).

    Components, in order (64 values in total):
      * 10 recent prices relative to the latest price (p / latest - 1)
      * 5 recent volumes scaled by the 20-candle max volume
      * 3 RSI, 3 MACD, 3 MACD signal, 3 MACD histogram (MACD scaled by max |MACD|)
      * 3 Bollinger band positions (0 = lower band, 1 = upper band)
      * 3 stochastic %K and 3 %D (scaled to 0-1)
      * 3 predicted prices relative to the latest price
      * 5 optimal-trade signals
      * 5 position fields (direction, unrealised PnL, SL, TP, size / balance)
      * 3 momentum values (latest price vs. the 5th/10th/20th most recent price)
      * 3 volatility values (std of returns, mean high-low range, ATR / price)
      * 5 market-regime values (trend strength, range, trend, support, resistance)
      * 4 trade-history stats over the last 10 trades

    Windows with fewer samples than their width are zero-padded on the left
    (predictions on the right), so every feature keeps a fixed index. NaN/inf
    entries are replaced by 0.

    An all-zero vector is returned in any of these cases:
      * fewer than 30 candles are loaded;
      * the step is past the end of the data;
      * no price features exist yet.
    """
    def _recent(key, width):
        # Newest ``width`` samples of feature ``key`` (may be shorter than width).
        return np.asarray(self.features[key][-width:], dtype=float)

    def _pad(values, width):
        # Left-pad with zeros to exactly ``width`` entries, keeping the newest last.
        values = np.asarray(values, dtype=float).ravel()[-width:]
        return np.concatenate([np.zeros(width - values.size), values])

    price_hist = self.features['price']
    if len(self.data) < 30 or self.current_step >= len(self.data) or len(price_hist) == 0:
        return np.zeros(STATE_SIZE)

    latest_price = price_hist[-1]
    components = []

    # Price history relative to the latest price (all zeros if that price is 0).
    if latest_price:
        components.append(_pad(_recent('price', 10) / latest_price - 1.0, 10))
    else:
        components.append(np.zeros(10))

    # Volume scaled by the 20-candle maximum (unscaled until 20 candles exist).
    volume_hist = self.features['volume']
    max_vol = max(volume_hist[-20:]) if len(volume_hist) >= 20 else 1
    components.append(_pad(_recent('volume', 5) / (max_vol or 1), 5))

    # RSI scaled to 0-1.
    components.append(_pad(_recent('rsi', 3) / 100.0, 3))

    # MACD line, signal and histogram share one scale: the largest |MACD| value.
    macd_line = _recent('macd', 3)
    macd_scale = max(float(np.max(np.abs(macd_line))) if macd_line.size else 0.0, 1e-5)
    for series in (macd_line, _recent('macd_signal', 3), _recent('macd_hist', 3)):
        components.append(_pad(series / macd_scale, 3))

    # Position of price inside the Bollinger bands (0.5 when the bands coincide).
    band_pos = [
        (p - lo) / (up - lo) if up != lo else 0.5
        for p, up, lo in zip(_recent('price', 3),
                             _recent('bollinger_upper', 3),
                             _recent('bollinger_lower', 3))
    ]
    components.append(_pad(band_pos, 3))

    # Stochastic oscillator scaled to 0-1.
    components.append(_pad(_recent('stoch_k', 3) / 100.0, 3))
    components.append(_pad(_recent('stoch_d', 3) / 100.0, 3))

    # First three predicted prices relative to the latest price.
    predicted = getattr(self, 'predicted_prices', None)
    if predicted is not None and len(predicted) > 0 and latest_price:
        preds = np.asarray(predicted[:3], dtype=float) / latest_price - 1.0
        components.append(np.concatenate([preds, np.zeros(3 - preds.size)]))
    else:
        components.append(np.zeros(3))

    signals = getattr(self, 'optimal_signals', None)
    if signals is not None and len(signals) > 0:
        # NOTE(review): this is the tail of the whole signal series and does not
        # depend on current_step, so it is constant over an episode. If the
        # signals are aligned with self.data, a step-relative window may be
        # intended. Beware look-ahead if they were derived from future prices.
        components.append(_pad(signals[-5:], 5))
    else:
        components.append(np.zeros(5))

    # Position info: direction, unrealised PnL, SL/TP distance, size ratio.
    position_info = np.zeros(5)
    if self.position in ('long', 'short'):
        entry = self.entry_price
        if self.position == 'long':
            position_info[0] = 1.0
            position_info[1] = (latest_price - entry) / entry
            position_info[2] = (self.stop_loss - entry) / entry
            position_info[3] = (self.take_profit - entry) / entry
        else:
            position_info[0] = -1.0
            position_info[1] = (entry - latest_price) / entry
            position_info[2] = (entry - self.stop_loss) / entry
            position_info[3] = (entry - self.take_profit) / entry
        # Balance can reach zero after losses; do not divide by it then.
        position_info[4] = self.position_size / self.balance if self.balance > 0 else 0.0
    components.append(position_info)

    # Momentum: rate of change against older prices.
    if len(price_hist) >= 20:
        components.append(np.array([
            latest_price / price_hist[-lag] - 1.0 if price_hist[-lag] != 0 else 0
            for lag in (5, 10, 20)
        ]))
    else:
        components.append(np.zeros(3))

    # Volatility features.
    if len(price_hist) >= 20:
        window = np.asarray(price_hist[-21:], dtype=float)
        volatility = np.std(np.diff(window) / window[:-1])

        ranges = []
        for i in range(max(0, self.current_step - 5), min(len(self.data), self.current_step + 1)):
            row = self.data[i]
            if self.data_format_is_list:
                high, low, close = row[2], row[3], row[4]  # list rows: high=2, low=3, close=4
            else:
                high, low, close = row['high'], row['low'], row['close']
            ranges.append((high - low) / close if close else 0.0)
        high_low_range = np.mean(ranges) if ranges else 0.0

        atr_hist = self.features['atr']
        atr_norm = atr_hist[-1] / latest_price if len(atr_hist) > 0 and latest_price else 0
        components.append(np.array([volatility, high_low_range, atr_norm]))
    else:
        components.append(np.zeros(3))

    # Market regime features.
    if len(price_hist) >= 50:
        ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
        ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
        trend_strength = abs(ema9 - ema21) / ema21 if ema21 else 0.0
        components.append(np.array([
            trend_strength,
            1.0 if self.is_uncertain_market() else 0.0,
            1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0,
            1.0 if self.is_near_support() else 0.0,
            1.0 if self.is_near_resistance() else 0.0,
        ]))
    else:
        components.append(np.zeros(5))

    # Trade history over the last 10 trades, normalised by balance.
    if self.trades:
        pnls = [t.get('pnl_dollar', 0) for t in self.trades[-10:]]
        wins = [p for p in pnls if p > 0]
        losses = [p for p in pnls if p <= 0]
        win_ratio = len(wins) / len(pnls)
        avg_profit = np.mean(wins) if wins else 0
        avg_loss = np.mean(losses) if losses else 0
        if self.balance > 0:
            history = [win_ratio, avg_profit / self.balance, avg_loss / self.balance,
                       pnls[-1] / self.balance]
        else:
            history = [win_ratio, 0, 0, 0]
        components.append(np.array(history))
    else:
        components.append(np.zeros(4))

    state = np.concatenate([np.asarray(c, dtype=float).ravel() for c in components])
    state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)

    # Every component above has a fixed width. This only fires if that layout
    # and STATE_SIZE drift apart.
    if len(state) != STATE_SIZE:
        logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
        if len(state) < STATE_SIZE:
            state = np.pad(state, (0, STATE_SIZE - len(state)))
        else:
            state = state[:STATE_SIZE]

    return state
|
|
|
|
def get_expanded_state_size(self):
    """Return the length of the vector currently produced by ``get_state()``.

    NOTE(review): get_state() pads/truncates its output to STATE_SIZE, so this
    presently always equals STATE_SIZE.
    """
    return len(self.get_state())
|
|
|
|
async def expand_model_with_new_features(agent, env):
    """Grow ``agent``'s network when ``env`` produces a larger state vector.

    The new network uses hidden size 512, 3 LSTM layers and 8 attention heads.

    Returns:
        True when no expansion is needed or ``agent.expand_model`` succeeds.
        False when ``agent.expand_model`` reports failure.
    """
    new_state_size = env.get_expanded_state_size()

    # Nothing to do unless the environment's state outgrew the model input.
    if new_state_size <= agent.state_size:
        logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
        return True

    logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")
    expanded = agent.expand_model(
        new_state_size=new_state_size,
        new_hidden_size=512,
        new_lstm_layers=3,
        new_attention_heads=8,
    )

    if not expanded:
        logger.error("Failed to expand model")
        return False

    logger.info(f"Model successfully expanded to handle {new_state_size} features")
    return True
|
|
|
|
|
|
def calculate_reward(self, action):
    """
    Calculate reward for taking the given action.

    Args:
        action: The action taken (0=hold, 1=buy, 2=sell, 3=close)

    Returns:
        The calculated reward:
        * hold: +0.1 plus 10x the last candle's fractional move when an open
          position moved in its favour; otherwise -0.1 (including while flat).
        * buy/sell/close: minus one tenth of the estimated fee. The fee is $1
          per $1,900 of position size: the new size for buy/sell, the open
          size for close.
        * buy/sell additionally earn 10x the matching CNN pattern confidence
          (``long_confidence`` / ``short_confidence``) when available.

    Side effects:
        For buy/sell/close the estimated fee is added to ``self.total_fees``.
    """
    if self.data_format_is_list:
        current_price = self.data[self.current_step][4]  # Close price
    else:
        current_price = self.data[self.current_step]['close']

    # Fractional move of the close since the previous candle.
    price_change_pct = 0
    if self.current_step > 0:
        if self.data_format_is_list:
            prev_price = self.data[self.current_step - 1][4]  # Previous close
        else:
            prev_price = self.data[self.current_step - 1]['close']
        # Treat a zero previous close as "no move" instead of dividing by zero.
        if prev_price:
            price_change_pct = (current_price - prev_price) / prev_price

    pattern_confidence = 0
    if hasattr(self, 'cnn_patterns'):
        if action == 1 and 'long_confidence' in self.cnn_patterns:
            pattern_confidence = self.cnn_patterns['long_confidence']
        elif action == 2 and 'short_confidence' in self.cnn_patterns:
            pattern_confidence = self.cnn_patterns['short_confidence']

    reward = 0
    if action == 0:  # HOLD
        if self.position == 'long' and price_change_pct > 0:
            reward += 0.1 + price_change_pct * 10
        elif self.position == 'short' and price_change_pct < 0:
            reward += 0.1 + abs(price_change_pct) * 10
        else:
            # Wrong-direction move, no move, or no open position.
            reward -= 0.1
    elif action in (1, 2, 3):  # BUY, SELL or CLOSE
        # Opening prices the fee on the new size; closing on the open size.
        size = self.calculate_position_size() if action in (1, 2) else self.position_size
        fee = (size / 1900) * 1  # Trading fee in USD
        reward -= fee / 10  # Scaled down to a reasonable penalty
        self.total_fees = getattr(self, 'total_fees', 0) + fee

    reward += pattern_confidence * 10
    return reward
|
|
|
|
async def initialize_futures(self, exchange):
    """Configure hedge mode, cross margin and leverage for live futures trading.

    Skipped entirely in demo mode. If any exchange call fails, the error is
    logged and the instance switches to demo mode (``self.demo = True``), so
    later code treats it as a non-live session.

    Args:
        exchange: exchange client whose ``set_position_mode``,
            ``set_margin_mode`` and ``set_leverage`` are awaited.
    """
    if self.demo:
        return

    try:
        await exchange.set_position_mode(True)  # Hedge mode
        await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
        await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
        logger.info(f"Futures initialized with {self.leverage}x leverage")
    except Exception as e:
        logger.error(f"Failed to initialize futures trading: {str(e)}")
        logger.info("Falling back to demo mode for safety")
        # Set the attribute (not a local variable) so the fallback persists.
        self.demo = True
|
|
|
|
async def execute_real_trade(self, exchange, action, current_price):
    """Send a market futures order for ``action`` and return the exchange's order.

    Args:
        exchange: client exposing an awaitable ``create_order``.
        action: 1 opens a long, 2 opens a short, 3 closes the current position.
            Any other value (e.g. 0 = hold) places no order.
        current_price: currently unused; kept for interface compatibility.

    Returns:
        The order returned by ``create_order``. Returns None in three cases:
          * no order was needed;
          * there was no open position to close;
          * the request failed (the error is logged).
    """
    if action not in (1, 2, 3):
        # Hold / unknown action: nothing to send (and no ``order`` to return).
        return None

    if action == 3 and self.position not in ('long', 'short'):
        logger.warning("Close requested with no open position; no order sent")
        return None

    try:
        if action in (1, 2):
            side, position_side = ('buy', 'LONG') if action == 1 else ('sell', 'SHORT')
            order = await exchange.create_order(
                symbol=self.futures_symbol,
                type='market',
                side=side,
                amount=self.calculate_position_size(),
                params={'positionSide': position_side}
            )
            logger.info(f"Opened {position_side} position: {order}")
        else:
            position_side = 'LONG' if self.position == 'long' else 'SHORT'
            # Closing trades the opposite side of the open position.
            order = await exchange.create_order(
                symbol=self.futures_symbol,
                type='market',
                side='sell' if position_side == 'LONG' else 'buy',
                amount=self.position_size,
                params={'positionSide': position_side}
            )
            logger.info(f"Closed {position_side} position: {order}")

        return order

    except Exception as e:
        logger.error(f"Trade execution failed: {e}")
        return None
|
|
|
|
def trades_in_last_n_candles(self, n=20):
    """Return how many of the most recent trades fall within the last ``n`` candles.

    The window runs from the timestamp of candle ``current_step - n`` (clamped
    to 0) up to the current candle's timestamp, inclusive. Trades are scanned
    newest-first. Counting stops at the first trade that lacks a timestamp or
    falls outside the window.
    """
    if not self.trades:
        return 0

    now_idx = self.current_step
    start_idx = max(0, now_idx - n)
    if self.data_format_is_list:
        window_start = self.data[start_idx][0]
        window_end = self.data[now_idx][0]
    else:
        window_start = self.data[start_idx]['timestamp']
        window_end = self.data[now_idx]['timestamp']

    total = 0
    for trade in reversed(self.trades):
        if 'timestamp' not in trade:
            break
        stamp = trade['timestamp']
        if not (stamp >= window_start and stamp <= window_end):
            break
        total += 1
    return total
|
|
|
|
def count_consecutive_losses(self):
    """Return the length of the current losing streak.

    Trades are walked from newest to oldest while ``pnl_dollar`` is negative.
    A missing ``pnl_dollar`` counts as 0, which ends the streak.
    """
    losses = 0
    idx = len(self.trades) - 1
    while idx >= 0 and self.trades[idx].get('pnl_dollar', 0) < 0:
        losses += 1
        idx -= 1
    return losses
|
|
|
|
def is_volatile_market(self):
    """Return True when the last 20 prices deviate from their mean by over 1% on average.

    The measure is the mean absolute deviation divided by the mean price.
    Returns False when fewer than 20 prices are available.
    """
    if len(self.features['price']) < 20:
        return False

    window = self.features['price'][-20:]
    mean_price = sum(window) / len(window)
    mean_rel_dev = sum(abs(x - mean_price) / mean_price for x in window) / len(window)

    return mean_rel_dev > 0.01
|
|
|
|
def is_uptrend(self):
    """Return True when EMA(9) is rising and sits above EMA(21).

    Needs at least two samples of each EMA; otherwise returns False.
    """
    fast = self.features['ema_9']
    slow = self.features['ema_21']
    if min(len(fast), len(slow)) < 2:
        return False

    rising = fast[-1] > fast[-2]
    above_slow = fast[-1] > slow[-1]
    return rising and above_slow
|
|
|
|
def is_downtrend(self):
    """Return True when EMA(9) is falling and sits below EMA(21).

    Needs at least two samples of each EMA; otherwise returns False.
    """
    fast = self.features['ema_9']
    slow = self.features['ema_21']
    if min(len(fast), len(slow)) < 2:
        return False

    falling = fast[-1] < fast[-2]
    below_slow = fast[-1] < slow[-1]
    return falling and below_slow
|
|
|
|
def calculate_position_size(self):
    """Return the size for a new position under the risk-management rules.

    The sizing runs in this order:
      1. Risk 2% of ``balance`` against a ``STOP_LOSS_PERCENT`` stop at
         ``current_price``.
      2. Multiply by ``leverage``.
      3. Cap at ``balance * leverage``.
      4. Scale down by 10% per consecutive losing trade, never below 30%.

    Returns 0.0 when the balance or the price is not positive. This avoids a
    division by zero (price 0) or a negative size (negative balance).

    NOTE(review): the units look inconsistent. The risk formula divides by
    price (asset units), while the cap and the PnL/fee formulas in
    ``check_sl_tp`` / ``calculate_reward`` treat the size as a quote-currency
    amount. ``execute_real_trade`` passes it as the order ``amount``. Confirm
    the intended unit.
    """
    if self.balance <= 0 or self.current_price <= 0:
        return 0.0

    # Reduce position size after losses
    consecutive_losses = self.count_consecutive_losses()
    risk_factor = max(0.3, 1.0 - (consecutive_losses * 0.1))  # -10% per loss, min 30%

    # Risk budget: 2% of balance per trade
    max_risk_amount = self.balance * 0.02
    position_size = max_risk_amount / (STOP_LOSS_PERCENT / 100 * self.current_price)

    # Apply leverage, then cap at the leveraged balance
    position_size = position_size * self.leverage
    position_size = min(position_size, self.balance * self.leverage)

    return position_size * risk_factor
|
|
|
|
# ... existing identify_optimal_trades method ...
|
|
|
|
def update_cnn_patterns(self, candle_data=None):
    """
    Update ``self.cnn_patterns`` from multi-timeframe candle data.

    Despite the name, no neural network is involved here. For each of the 1m,
    1h and 1d timeframes the block scores two things:
      * an EMA(20)/EMA(50) trend test on up to the last 100 closes;
      * whether the latest close sits near the most recent local bottom or
        top (from ``find_local_extrema``).

    The scores are combined with weights 1m=0.2, 1h=0.3, 1d=0.5 into
    ``long_confidence`` and ``short_confidence``, each capped at 1.0.

    Args:
        candle_data: Mapping of timeframe -> list of OHLCV rows. Index 4 of
            each row is read as the close. Timeframes with fewer than 30
            candles are skipped. Nothing happens if the argument is empty or
            any required timeframe is missing.

    Errors are logged and swallowed so a bad candle batch cannot stop the bot.
    """
    if not candle_data:
        return

    def _extremum_signal(closes, indices):
        """Return (near, confidence) for the most recent extremum index in ``indices``.

        "Near" means two things: the extremum is within the last 5 candles,
        and the latest close is within 1% of it. Confidence then falls
        linearly from 0.8 (at the extremum) to 0.3 (1% away).
        """
        # len() check instead of truthiness so numpy index arrays also work.
        if indices is not None and len(indices) > 0:
            last = indices[-1]
            if len(closes) - last < 5:
                dist = abs(closes[-1] - closes[last]) / closes[last]
                if dist < 0.01:
                    return True, 0.8 - dist * 50
        return False, 0

    try:
        required_timeframes = ['1m', '1h', '1d']
        if not all(tf in candle_data for tf in required_timeframes):
            logger.warning("Missing required timeframes for CNN pattern detection")
            return

        if not hasattr(self, 'cnn_patterns'):
            self.cnn_patterns = {}

        features = {}
        for tf in required_timeframes:
            candles = candle_data[tf]
            if candles is None or len(candles) < 30:
                continue

            # Float dtype so the EMAs are never computed in integer arithmetic.
            closes = np.array([c[4] for c in candles[-100:]], dtype=float)

            ema20 = self._calculate_ema(closes, 20)
            ema50 = self._calculate_ema(closes, 50)
            uptrend = ema20[-1] > ema50[-1] and closes[-1] > ema20[-1]
            downtrend = ema20[-1] < ema50[-1] and closes[-1] < ema20[-1]

            tops, bottoms = find_local_extrema(closes, window=10)
            near_bottom, bottom_confidence = _extremum_signal(closes, bottoms)
            near_top, top_confidence = _extremum_signal(closes, tops)

            features[tf] = {
                'uptrend': uptrend,
                'downtrend': downtrend,
                'near_bottom': near_bottom,
                'bottom_confidence': bottom_confidence,
                'near_top': near_top,
                'top_confidence': top_confidence
            }

        # Weight each timeframe (higher weight for longer timeframes)
        weights = {'1m': 0.2, '1h': 0.3, '1d': 0.5}
        long_confidence = 0
        short_confidence = 0
        for tf, tf_features in features.items():
            weight = weights.get(tf, 0.2)

            if tf_features['uptrend'] or tf_features['near_bottom']:
                long_confidence += weight * (0.6 if tf_features['uptrend'] else 0) + \
                    weight * (tf_features['bottom_confidence'] if tf_features['near_bottom'] else 0)

            if tf_features['downtrend'] or tf_features['near_top']:
                short_confidence += weight * (0.6 if tf_features['downtrend'] else 0) + \
                    weight * (tf_features['top_confidence'] if tf_features['near_top'] else 0)

        long_confidence = min(1.0, long_confidence)
        short_confidence = min(1.0, short_confidence)

        self.cnn_patterns = {
            'long_confidence': long_confidence,
            'short_confidence': short_confidence,
            'features': features
        }

        logger.debug(f"Updated CNN patterns - Long: {long_confidence:.2f}, Short: {short_confidence:.2f}")

    except Exception as e:
        logger.error(f"Error updating CNN patterns: {e}")
|
|
|
|
def _calculate_ema(self, data, span):
|
|
"""Calculate exponential moving average"""
|
|
alpha = 2 / (span + 1)
|
|
alpha_rev = 1 - alpha
|
|
|
|
ema = np.zeros_like(data)
|
|
ema[0] = data[0]
|
|
|
|
for i in range(1, len(data)):
|
|
ema[i] = alpha * data[i] + alpha_rev * ema[i-1]
|
|
|
|
return ema
|
|
|
|
def is_uncertain_market(self):
    """Return True when the market looks range-bound.

    Requires at least 30 prices (otherwise returns False). The market counts
    as range-bound when either of these holds:
      * EMA(9) and EMA(21) are within 0.2% of each other;
      * over the last 10 prices, up-moves and non-up-moves differ by fewer
        than 3.
    """
    prices = self.features['price']
    if len(prices) < 30:
        return False

    fast = self.features['ema_9']
    slow = self.features['ema_21']
    if len(fast) > 0 and len(slow) > 0:
        if abs(fast[-1] - slow[-1]) / slow[-1] < 0.002:
            return True

    # Fewer than 30 prices already returned above, so a 10-price window exists.
    window = prices[-10:]
    ups = sum(1 for prev, cur in zip(window, window[1:]) if cur > prev)
    downs = (len(window) - 1) - ups
    return abs(ups - downs) < 3
|
|
|
|
def is_near_support(self):
    """Return True when the latest price is close to a support level.

    Requires at least 30 prices (otherwise returns False). Support is either
    of these:
      * the latest lower Bollinger band, within 0.5% (or price below it);
      * the lowest of the last 20 prices, within 1%.
    Both distances are relative to the latest price.
    """
    prices = self.features['price']
    if len(prices) < 30:
        return False

    last = prices[-1]
    lower = self.features['bollinger_lower']

    near_band = len(lower) > 0 and (last - lower[-1]) / last < 0.005
    if near_band:
        return True

    recent_floor = min(prices[-20:])
    if (last - recent_floor) / last < 0.01:
        return True

    return False
|
|
|
|
def is_near_resistance(self):
    """Return True when the latest price is close to a resistance level.

    Requires at least 30 prices (otherwise returns False). Resistance is
    either of these:
      * the latest upper Bollinger band, within 0.5% (or price above it);
      * the highest of the last 20 prices, within 1%.
    Both distances are relative to the latest price.
    """
    prices = self.features['price']
    if len(prices) < 30:
        return False

    last = prices[-1]
    upper = self.features['bollinger_upper']

    near_band = len(upper) > 0 and (upper[-1] - last) / last < 0.005
    if near_band:
        return True

    recent_ceiling = max(prices[-20:])
    if (recent_ceiling - last) / last < 0.01:
        return True

    return False
|
|
|
|
def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
    """
    Add account metrics and a candlestick chart of recent candles to TensorBoard

    Parameters:
    - writer: TensorBoard writer. When None, a temporary SummaryWriter is
      created for this call and closed before returning.
    - step: Current step (global_step for every value logged)
    - title: Title for the chart

    Scalars logged:
      * Balance, Total_PnL, Position_Size, Max_Drawdown, Win_Rate, Trade_Count;
      * Total_Fees and Net_PnL, when ``total_fees`` exists.

    The chart covers up to the last 101 candles ending at ``current_step``.
    The last 10 trades are passed to ``create_candlestick_figure``. All errors
    are logged and swallowed so charting never interrupts training.
    """
    owned_writer = None
    try:
        if writer is None:
            from torch.utils.tensorboard import SummaryWriter
            writer = owned_writer = SummaryWriter()

        # Log basic metrics
        writer.add_scalar('Balance', self.balance, step)
        writer.add_scalar('Total_PnL', self.total_pnl, step)

        if hasattr(self, 'total_fees'):
            writer.add_scalar('Total_Fees', self.total_fees, step)
            writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)

        writer.add_scalar('Position_Size', self.position_size, step)
        writer.add_scalar('Max_Drawdown', self.max_drawdown, step)

        closed = self.win_count + self.loss_count
        win_rate = self.win_count / closed if closed > 0 else 0
        writer.add_scalar('Win_Rate', win_rate, step)

        writer.add_scalar('Trade_Count', len(self.trades), step)

        if len(self.data) <= 0:
            logger.warning("No data available for candlestick chart")
            return

        start_idx = max(0, self.current_step - 100)
        end_idx = self.current_step
        recent_trades = self.trades[-10:] if self.trades else []

        fig = None
        try:
            fig = create_candlestick_figure(
                self.data[start_idx:end_idx + 1],
                title=title,
                trades=recent_trades
            )
            writer.add_figure('Candlestick_Chart', fig, step)
        except Exception as e:
            logger.error(f"Error creating candlestick chart: {e}")
        finally:
            # Release the figure even when logging it failed.
            if fig is not None:
                plt.close(fig)

    except Exception as e:
        # Continue execution even if chart fails
        logger.error(f"Error adding chart to TensorBoard: {e}")
    finally:
        # A writer created here would otherwise stay open (and a fresh run
        # directory would be started) on every call.
        if owned_writer is not None:
            owned_writer.close()
|
|
|
|
def get_realtime_state(self, tick_data):
    """Build a low-latency state vector from a single tick.

    Meant as a streamlined counterpart to ``get_state()`` for real-time
    decision making. Not implemented yet: ``tick_data`` is ignored and an
    all-zero float32 vector of length ``self.observation_space.shape[0]``
    is returned.

    TODO: Implement optimized state creation from tick data (e.g. price,
    volume and short/long EMAs derived from the tick stream), converted to
    the same layout the policy network expects.
    """
    state_dim = self.observation_space.shape[0]
    # Placeholder until a real-time feature pipeline exists
    return np.zeros(state_dim, dtype=np.float32)
|
|
|
|
# Ensure GPU usage if available
|
|
def get_device():
    """Return the torch device to run on: CUDA when available, otherwise CPU.

    The choice is logged. On the CUDA path ``torch.backends.cudnn.benchmark``
    is switched on so cuDNN can autotune its kernels for repeated input
    shapes.
    """
    if not torch.cuda.is_available():
        logger.info("GPU not available, using CPU")
        return torch.device("cpu")

    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    # Let cuDNN pick the fastest algorithms for the (fixed) input shapes
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda")
|
|
|
|
# DQN trading agent. Networks and tensors live on self.device (get_device() unless a device is passed in).
|
|
class Agent:
    """Deep Q-learning trading agent built on ``LSTMAttentionDQN``.

    Owns a policy network, a target network, an Adam optimizer, a replay
    memory, an epsilon-greedy exploration schedule and a lazily created
    TensorBoard writer. Architecture hyper-parameters are kept as attributes
    so they stay in sync with the networks (including after expand_model())
    and can be restored by load().
    """

    def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
        """Initialize Agent with architecture parameters stored as attributes.

        Args:
            state_size: Length of the state vector fed to the Q-network.
            action_size: Number of discrete actions.
            hidden_size: Hidden width passed to ``LSTMAttentionDQN``.
            lstm_layers: LSTM layer count passed to ``LSTMAttentionDQN``.
            attention_heads: Attention head count passed to ``LSTMAttentionDQN``.
            device: torch device to run on; ``get_device()`` chooses one when None.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.attention_heads = attention_heads

        self.device = device if device is not None else get_device()

        self.policy_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
        self.target_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        # The target network only produces bootstrap targets, so keep it in
        # eval mode (expand_model() does the same for rebuilt networks).
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)

        self.memory = ReplayMemory(MEMORY_SIZE)

        # Exploration schedule (see update_epsilon)
        self.epsilon = EPSILON_START
        self.epsilon_start = EPSILON_START
        self.epsilon_end = EPSILON_END
        self.epsilon_decay = EPSILON_DECAY
        self.epsilon_min = EPSILON_END

        # Number of optimisation steps performed by learn()
        self.steps_done = 0

        # TensorBoard writer, created on first use
        self.writer = None

        # Gradient scaler for mixed-precision training; only used on CUDA
        self.scaler = torch.amp.GradScaler('cuda') if self.device.type == "cuda" else None

        # Cache for multi-timeframe candle data
        self.candle_cache = CandleCache()

        # Run name used for the TensorBoard log directory
        self.model_name = f"LSTM_Attention_DQN_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

        logger.info(f"Initialized agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
        logger.info(f"Using device: {self.device}")

    def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
        """Rebuild the policy/target networks with a new input size and/or capacity.

        Parameters whose shapes are unchanged are copied across. For
        ``fc1.weight`` the overlapping block of the old matrix is copied into
        the new one, and the rest keeps its fresh initialisation. Other
        mismatched parameters stay freshly initialised and are logged. The
        optimizer is recreated, so its state is lost.

        Args:
            new_state_size: New state-vector length.
            new_hidden_size: New hidden width.
            new_lstm_layers: New LSTM layer count.
            new_attention_heads: New attention head count.

        Returns:
            bool: Always True.
        """
        logger.info(f"Expanding model: {self.state_size} → {new_state_size}, "
                    f"hidden: {self.hidden_size} → {new_hidden_size}")

        old_state_dict = self.policy_net.state_dict()

        new_policy_net = LSTMAttentionDQN(new_state_size, self.action_size,
                                          new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
        new_target_net = LSTMAttentionDQN(new_state_size, self.action_size,
                                          new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)

        # Transfer weights for common layers
        new_state_dict = new_policy_net.state_dict()
        for name, param in old_state_dict.items():
            if name not in new_state_dict:
                continue
            target = new_state_dict[name]
            if target.shape == param.shape:
                new_state_dict[name] = param
            elif name == "fc1.weight" and target.dim() == 2 and param.dim() == 2:
                # Copy only the overlapping region so both growing and shrinking
                # the input/hidden dimensions work (a plain slice assignment
                # fails as soon as the row count changes).
                rows = min(target.shape[0], param.shape[0])
                cols = min(target.shape[1], param.shape[1])
                target[:rows, :cols] = param[:rows, :cols]
            else:
                logger.info(f"Layer {name} shapes don't match: {param.shape} vs {target.shape}")

        new_policy_net.load_state_dict(new_state_dict)
        new_target_net.load_state_dict(new_state_dict)

        self.policy_net = new_policy_net
        self.target_net = new_target_net
        self.target_net.eval()

        # The old optimizer references the discarded parameters
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)

        # Keep the recorded architecture in sync with the rebuilt networks
        self.state_size = new_state_size
        self.hidden_size = new_hidden_size
        self.lstm_layers = new_lstm_layers
        self.attention_heads = new_attention_heads

        total_params = sum(p.numel() for p in self.policy_net.parameters())
        logger.info(f"New model size: {total_params:,} parameters")

        return True

    def select_action(self, state, training=True, candle_data=None):
        """
        Select an action using the policy network.

        Args:
            state: The current state (anything ``torch.as_tensor`` accepts).
            training: When True, act randomly with probability ``self.epsilon``.
            candle_data: Optional dict holding '1m', '1h' and '1d' candle data
                ('1s' is not used yet). When all three keys are present the
                candles are converted with prepare_candle_tensor() and passed
                to the policy network as extra inputs.

        Returns:
            int: The selected action index.
        """
        # Explore first: a random action needs no tensor preparation.
        if training and random.random() < self.epsilon:
            return random.randrange(self.action_size)

        cnn_inputs = None
        if candle_data and all(k in candle_data for k in ('1m', '1h', '1d')):
            cnn_inputs = (
                self.prepare_candle_tensor(candle_data['1m']),
                self.prepare_candle_tensor(candle_data['1h']),
                self.prepare_candle_tensor(candle_data['1d']),
            )

        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            if cnn_inputs is not None:
                q_values = self.policy_net(state_tensor, *cnn_inputs)
            else:
                q_values = self.policy_net(state_tensor)

        # NOTE(review): max(1) assumes the network returns (batch, action_size)
        return q_values.max(1)[1].item()

    def prepare_candle_tensor(self, candles, max_candles=300):
        """Convert candle data to a min-max normalised tensor for CNN input.

        Args:
            candles: None, or a sequence (list/tuple/2-D NumPy array) of rows
                indexed as ``[timestamp, open, high, low, close, volume]``.
                This presumably matches ccxt OHLCV rows; confirm against the
                candle cache.
            max_candles: Length of the time axis. Only the most recent
                ``max_candles`` rows are kept, and shorter inputs are
                left-padded with zeros.

        Returns:
            torch.Tensor of shape ``(1, 5, max_candles)`` on ``self.device``
            (batch, OHLCV channel, time). All zeros when there are no candles.
        """
        # ``not candles`` raises for multi-row NumPy arrays, so check length explicitly
        if candles is None or len(candles) == 0:
            return torch.zeros((1, 5, max_candles), device=self.device)

        recent = candles[-max_candles:]
        ohlcv = np.array([[c[1], c[2], c[3], c[4], c[5]] for c in recent], dtype=np.float32)

        # Per-column min-max normalisation; constant columns map to 0 instead of NaN
        min_vals = ohlcv.min(axis=0, keepdims=True)
        range_vals = ohlcv.max(axis=0, keepdims=True) - min_vals
        range_vals[range_vals == 0] = 1
        ohlcv = (ohlcv - min_vals) / range_vals

        padded = np.zeros((max_candles, 5), dtype=np.float32)
        padded[-len(ohlcv):] = ohlcv

        # [time, channel] -> [batch, channel, time]
        return torch.from_numpy(np.ascontiguousarray(padded.T)).unsqueeze(0).to(self.device)

    def _compute_td_loss(self, states, actions, rewards, next_states, dones):
        """Return the smooth-L1 TD loss between Q(s, a) and the target-network bootstrap."""
        current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
        # Compare in float32 so the AMP and full-precision paths agree
        return F.smooth_l1_loss(current_q_values.float(), target_q_values.unsqueeze(1).float())

    def learn(self):
        """Run one optimisation step on a replay batch.

        Uses mixed precision (autocast + GradScaler) on CUDA and full precision
        otherwise. Gradients are clipped to norm 1.0, and the target network
        is synced every ``TARGET_UPDATE`` calls that reach the optimiser.

        Returns:
            float | None: The loss value, or None when the memory holds fewer
            than ``BATCH_SIZE`` transitions or an error occurred (the error is
            logged).
        """
        if len(self.memory) < BATCH_SIZE:
            return None

        try:
            experiences = self.memory.sample(BATCH_SIZE)

            # Stack into single NumPy arrays first; building a tensor from a
            # list of ndarrays element by element is very slow.
            states = torch.as_tensor(np.asarray([e.state for e in experiences], dtype=np.float32), device=self.device)
            actions = torch.as_tensor(np.asarray([e.action for e in experiences], dtype=np.int64), device=self.device)
            rewards = torch.as_tensor(np.asarray([e.reward for e in experiences], dtype=np.float32), device=self.device)
            next_states = torch.as_tensor(np.asarray([e.next_state for e in experiences], dtype=np.float32), device=self.device)
            dones = torch.as_tensor(np.asarray([e.done for e in experiences], dtype=np.float32), device=self.device)

            self.optimizer.zero_grad()

            if self.device.type == "cuda" and self.scaler is not None:
                # Only the forward pass runs under autocast; backward and the
                # optimiser step happen outside it, as the AMP docs recommend.
                with torch.amp.autocast("cuda"):
                    loss = self._compute_td_loss(states, actions, rewards, next_states, dones)
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradients
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss = self._compute_td_loss(states, actions, rewards, next_states, dones)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
                self.optimizer.step()

            self.steps_done += 1
            if self.steps_done % TARGET_UPDATE == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            return loss.item()

        except Exception as e:
            logger.error(f"Error during learning: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None

    def update_epsilon(self, episode):
        """Linearly decay epsilon from epsilon_start to epsilon_end over epsilon_decay episodes.

        Args:
            episode: Current episode number.

        Returns:
            float: The new ``self.epsilon`` (never below ``self.epsilon_min``).
        """
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            max(0, (self.epsilon_decay - episode)) / self.epsilon_decay
        self.epsilon = max(self.epsilon_min, epsilon)
        return self.epsilon

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path="models/trading_agent_best_pnl.pt"):
        """Save the agent with robust_save(), creating the parent directory if needed.

        Returns:
            bool: True on success; False if robust_save reported failure or
            anything raised (the error is logged).
        """
        try:
            directory = os.path.dirname(path)
            # os.makedirs('') raises, so only create a directory when the path has one
            if directory:
                os.makedirs(directory, exist_ok=True)

            success = robust_save(self, path)

            if success:
                logger.info(f"Model saved successfully to {path}")
                return True
            logger.error(f"All save attempts failed for path: {path}")
            return False

        except Exception as e:
            logger.error(f"Error in save method: {e}")
            logger.error(traceback.format_exc())
            return False

    def _load_checkpoint(self, path):
        """Return the raw checkpoint at ``path``, trying several torch.load modes in turn."""
        try:
            logger.info(f"Attempting to load model with weights_only=False: {path}")
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location=self.device, weights_only=False)
            logger.info("Model loaded successfully with weights_only=False")
            return checkpoint
        except Exception as e1:
            logger.warning(f"Failed to load with weights_only=False: {e1}")

        try:
            logger.info("Attempting to load with safe_globals context manager")
            from torch.serialization import safe_globals
            try:
                from numpy._core.multiarray import scalar as np_scalar  # NumPy >= 2
            except ImportError:
                from numpy.core.multiarray import scalar as np_scalar  # NumPy < 2
            # safe_globals expects the callables themselves, not dotted-name strings
            with safe_globals([np_scalar]):
                checkpoint = torch.load(path, map_location=self.device)
            logger.info("Model loaded successfully with safe_globals")
            return checkpoint
        except Exception as e2:
            logger.warning(f"Failed to load with safe_globals: {e2}")

        # Last resort: explicit pickle module (same trust caveat as above)
        logger.info("Attempting to load with pickle_module")
        import pickle
        checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
        logger.info("Model loaded successfully with pickle_module")
        return checkpoint

    def load(self, path="models/trading_agent_best_pnl.pt"):
        """Load a trained model with improved error handling for PyTorch 2.6 compatibility.

        Restores the policy and target network weights, plus, when present in
        the checkpoint, the optimizer state, epsilon and the architecture
        attributes (state_size, action_size, hidden_size, lstm_layers,
        attention_heads).

        Raises:
            Exception: Whatever the failing step raised (for example KeyError
            when 'policy_net'/'target_net' is missing). It is logged, then re-raised.
        """
        try:
            checkpoint = self._load_checkpoint(path)

            self.policy_net.load_state_dict(checkpoint['policy_net'])
            self.target_net.load_state_dict(checkpoint['target_net'])

            try:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as e:
                logger.warning(f"Could not load optimizer state: {e}")

            if 'epsilon' in checkpoint:
                self.epsilon = checkpoint['epsilon']

            if 'state_size' in checkpoint:
                self.state_size = checkpoint['state_size']
            if 'action_size' in checkpoint:
                self.action_size = checkpoint['action_size']

            if 'hidden_size' in checkpoint:
                self.hidden_size = checkpoint['hidden_size']
            else:
                # Older checkpoints: infer the width from the first linear layer
                try:
                    self.hidden_size = self.policy_net.fc1.weight.shape[0]
                    logger.info(f"Inferred hidden_size={self.hidden_size} from model")
                except AttributeError:
                    self.hidden_size = 256  # Default value
                    logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")

            self.lstm_layers = checkpoint.get('lstm_layers', 2)
            self.attention_heads = checkpoint.get('attention_heads', 4)

            logger.info(f"Model loaded successfully from {path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            logger.error(traceback.format_exc())
            raise

    def add_chart_to_tensorboard(self, env, step):
        """Log trading metrics for ``env`` and, with >= 100 rows of data, a candlestick chart.

        Creates ``self.writer`` on first use. Does nothing beyond that when
        ``env.data`` is missing or has fewer than 20 rows. Scalar metrics are
        logged only if the environment exposes the matching attribute. Errors
        are logged, not raised.
        """
        try:
            if getattr(self, 'writer', None) is None:
                self.writer = SummaryWriter(log_dir=f'runs/{self.model_name}')

            if not hasattr(env, 'data') or len(env.data) < 20:
                logger.warning("Not enough data for chart in TensorBoard")
                return

            # Position may be stored as a string ('flat'/'long'/'short') or a number
            position_value = 0
            if hasattr(env, 'position'):
                if isinstance(env.position, str):
                    position_map = {'flat': 0, 'long': 1, 'short': -1}
                    position_value = position_map.get(env.position.lower(), 0)
                else:
                    position_value = float(env.position)
            self.writer.add_scalar('Trading/Position', position_value, step)

            optional_metrics = (
                ('balance', 'Trading/Balance'),
                ('total_pnl', 'Trading/Total_PnL'),
                ('max_drawdown', 'Trading/Drawdown'),
                ('win_rate', 'Trading/Win_Rate'),
                ('trade_count', 'Trading/Trade_Count'),
            )
            for attr, tag in optional_metrics:
                if hasattr(env, attr):
                    self.writer.add_scalar(tag, getattr(env, attr), step)

            if hasattr(env, 'total_fees'):
                self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
                if hasattr(env, 'total_pnl'):
                    self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)

            if len(env.data) >= 100:
                fig = None
                try:
                    recent_data = env.data[-100:]
                    recent_trades = None
                    if hasattr(env, 'trades') and len(env.trades) > 0:
                        recent_trades = env.trades[-10:]

                    # Pass trades/title by keyword so the call does not depend
                    # on create_candlestick_figure's parameter order.
                    fig = create_candlestick_figure(
                        recent_data,
                        title=f"Trading Chart - Step {step}",
                        trades=recent_trades,
                    )
                    if fig is not None:
                        self.writer.add_figure('Trading/Chart', fig, step)
                except Exception as e:
                    logger.warning(f"Error creating candlestick chart: {e}")
                finally:
                    # Always release the figure, even if add_figure failed
                    if fig is not None:
                        plt.close(fig)
        except Exception as e:
            logger.error(f"Error in add_chart_to_tensorboard: {e}")

    def select_action_realtime(self, state):
        """
        Select action with minimal latency for real-time trading.

        Acts greedily (no exploration). Uses ``policy_net.forward_realtime``
        when the network defines it, otherwise the network's regular forward
        pass.

        TODO: Implement optimized action selection for real-time trading
        """
        # The state must be on the model's device; a CPU tensor would fail on a CUDA model
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        forward = getattr(self.policy_net, 'forward_realtime', self.policy_net)

        with torch.no_grad():
            q_values = forward(state_tensor)

        return q_values.max(1)[1].item()

    def forward_realtime(self, state):
        """
        Optimized forward pass for real-time trading with minimal latency.

        Agent has no ``forward`` of its own, so this delegates to the policy
        network's regular forward pass.

        TODO: Implement streamlined forward pass that prioritizes speed
        (e.g. a smaller model, skipped layers or quantized weights).
        """
        return self.policy_net(state)
|
|
|
|
def _parse_candle_message(raw):
    """Decode one websocket frame into a candle dict, or return None if it carries no valid kline."""
    try:
        message = json.loads(raw)
    except (TypeError, ValueError) as e:
        logger.warning(f"Skipping undecodable websocket message: {e}")
        return None

    kline = None
    if isinstance(message, dict):
        if 'data' in message:
            kline = message['data']
        elif isinstance(message.get('d'), dict):
            # Nested {"d": {"k": {...}}} shape used by MEXC spot v3 JSON kline
            # pushes (NOTE(review): confirm against the current MEXC docs).
            kline = message['d'].get('k')

    # Subscription acks, pings, etc. carry no kline payload
    if not isinstance(kline, dict):
        return None

    try:
        return {
            'timestamp': kline['t'],
            'open': float(kline['o']),
            'high': float(kline['h']),
            'low': float(kline['l']),
            'close': float(kline['c']),
            'volume': float(kline['v'])
        }
    except (KeyError, TypeError, ValueError) as e:
        logger.warning(f"Skipping malformed kline message: {e}")
        return None


async def get_live_prices(symbol="ETH/USDT", timeframe="1m", max_reconnects=5, reconnect_delay=5):
    """Get live price data using websockets.

    Async generator. Each yielded item is a dict with 'timestamp', 'open',
    'high', 'low', 'close' and 'volume'. Frames without a kline payload, or
    with a malformed one, are skipped instead of closing the stream. When
    the connection fails, the generator waits ``reconnect_delay`` seconds and
    reconnects. It stops after ``max_reconnects`` consecutive failures with
    no candle received in between.

    NOTE(review): the channel uses the lower-cased symbol and the raw
    ``timeframe`` (e.g. "1m"). MEXC's v3 docs describe upper-case symbols and
    intervals such as "Min1", so confirm the channel name.

    Args:
        symbol: Trading pair, e.g. "ETH/USDT".
        timeframe: Kline interval string placed in the channel name.
        max_reconnects: Consecutive failed connections tolerated before giving up.
        reconnect_delay: Seconds to wait before reconnecting.
    """
    uri = "wss://stream.mexc.com/ws"
    subscribe_msg = {
        "method": "SUBSCRIPTION",
        "params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
    }

    failures = 0
    while True:
        try:
            async with websockets.connect(uri) as websocket:
                await websocket.send(json.dumps(subscribe_msg))
                logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")

                while True:
                    response = await websocket.recv()
                    candle = _parse_candle_message(response)
                    if candle is None:
                        continue
                    # A working stream resets the failure budget
                    failures = 0
                    yield candle
        except Exception as e:
            logger.error(f"Websocket error: {e}")
            failures += 1
            if failures > max_reconnects:
                return
            await asyncio.sleep(reconnect_delay)
|
|
|
|
def _save_agent_model(agent, path, use_compact_save):
    """Persist ``agent`` to ``path`` via compact_save() or agent.save() and return its result."""
    if use_compact_save:
        return compact_save(agent, path)
    return agent.save(path)


async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, use_compact_save=False):
    """
    Train the agent in the environment.

    Each episode resets ``env``, steps it with ``agent.select_action`` (using
    multi-timeframe candle data when an exchange could be initialised),
    stores transitions in ``agent.memory`` and calls ``agent.learn()``.
    Metrics go to TensorBoard (when a writer is available) and to the log.
    The following files are written:
    - under ``models/``: the best-reward, best-PnL and best-net-PnL models,
      a checkpoint every 10 episodes, and a final model;
    - ``training_stats.csv``: the collected statistics.

    The exchange connection is closed on exit, including when training is
    cancelled or interrupted.

    Args:
        agent: The agent to train
        env: The trading environment
        num_episodes: Number of episodes to train for
        max_steps_per_episode: Maximum steps per episode
        use_compact_save: Whether to use compact save (for low disk space)

    Returns:
        dict: Training statistics, one list per metric. Lists are padded with
        NaN to equal length before saving.
    """
    try:
        if agent.writer is None:
            agent.writer = SummaryWriter(log_dir=f'runs/{agent.model_name}')
        writer = agent.writer
    except Exception as e:
        logger.error(f"Failed to initialize TensorBoard: {e}")
        writer = None

    try:
        exchange = await initialize_exchange()
        logger.info("Initialized exchange for data fetching")
    except Exception as e:
        logger.error(f"Failed to initialize exchange: {e}")
        exchange = None

    stats = {
        'episode_rewards': [],
        'episode_lengths': [],
        'balances': [],
        'win_rates': [],
        'episode_pnls': [],
        'cumulative_pnl': [],
        'drawdowns': [],
        'trade_counts': [],
        'loss_values': [],
        'fees': [],  # Track fees
        'net_pnl_after_fees': []  # Track net PnL after fees
    }

    best_reward = float('-inf')
    best_pnl = float('-inf')
    best_net_pnl = float('-inf')  # Best net PnL (after fees)

    # Running totals, so cumulative curves cost O(1) per episode instead of re-summing history
    cumulative_pnl_total = 0
    cumulative_net_pnl_total = 0

    os.makedirs('models', exist_ok=True)

    def clean_memory():
        """Clean up memory to avoid memory leaks"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    try:
        for episode in range(num_episodes):
            try:
                clean_memory()

                state = env.reset()
                episode_reward = 0
                episode_losses = []

                # Fetch multi-timeframe data at the start of the episode
                candle_data = None
                if exchange:
                    try:
                        candle_data = await fetch_multi_timeframe_data(
                            exchange, "ETH/USDT", agent.candle_cache
                        )
                        env.update_cnn_patterns(candle_data)
                        logger.info(f"Fetched multi-timeframe data for episode {episode+1}")
                    except Exception as e:
                        logger.error(f"Failed to fetch candle data: {e}")

                # Circuit breaker for repeated failures inside one episode
                consecutive_errors = 0
                max_consecutive_errors = 5
                # Episode length; stays 0 when max_steps_per_episode is 0
                steps_taken = 0

                for step in range(max_steps_per_episode):
                    steps_taken = step + 1
                    global_step = episode * max_steps_per_episode + step
                    try:
                        action = agent.select_action(state, training=True, candle_data=candle_data)
                        next_state, reward, done, _info = env.step(action)
                        agent.memory.push(state, action, reward, next_state, done)
                        state = next_state
                        episode_reward += reward

                        if len(agent.memory) > BATCH_SIZE:
                            try:
                                loss = agent.learn()
                                if loss is not None:
                                    episode_losses.append(loss)
                                    if writer:
                                        writer.add_scalar('Loss/step', loss, global_step)
                                    # Reset the breaker after a successful learning step
                                    consecutive_errors = 0
                            except Exception as e:
                                logger.error(f"Error during learning: {e}")
                                consecutive_errors += 1
                                if consecutive_errors >= max_consecutive_errors:
                                    logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                                    break

                        # Periodic hard sync in addition to the one learn() performs
                        # every TARGET_UPDATE optimisation steps
                        if step % TARGET_UPDATE == 0:
                            agent.update_target_network()

                        # Refresh predictions and candle data periodically
                        if step % 50 == 0:
                            try:
                                if hasattr(env, 'update_price_predictions'):
                                    env.update_price_predictions()
                                if hasattr(env, 'identify_optimal_trades'):
                                    env.identify_optimal_trades()

                                if exchange:
                                    try:
                                        candle_data = await fetch_multi_timeframe_data(
                                            exchange, "ETH/USDT", agent.candle_cache
                                        )
                                        env.update_cnn_patterns(candle_data)
                                        logger.info(f"Updated multi-timeframe data at step {step}")
                                    except Exception as e:
                                        logger.error(f"Failed to fetch candle data: {e}")
                            except Exception as e:
                                logger.warning(f"Error updating predictions: {e}")

                        if step % 200 == 0 and step > 0:
                            clean_memory()

                        if step % 100 == 0 or (step == max_steps_per_episode - 1) or done:
                            try:
                                if writer:
                                    agent.add_chart_to_tensorboard(env, global_step)
                            except Exception as e:
                                logger.warning(f"Error adding chart to TensorBoard: {e}")

                        if done:
                            break

                    except Exception as e:
                        logger.error(f"Error in training step: {e}")
                        consecutive_errors += 1
                        if consecutive_errors >= max_consecutive_errors:
                            logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                            break

                # Episode statistics
                balance = env.balance
                pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
                fees = env.total_fees if hasattr(env, 'total_fees') else 0
                # NOTE(review): if env.balance already has fees deducted, this counts them twice
                net_pnl = pnl - fees

                trade_analysis = env.analyze_trades() if hasattr(env, 'analyze_trades') else None
                trade_analysis = trade_analysis or {}
                win_rate = trade_analysis.get('win_rate', 0)
                trade_count = trade_analysis.get('total_trades', 0)
                max_drawdown = trade_analysis.get('max_drawdown', 0)

                avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0

                if writer:
                    writer.add_scalar('Reward/episode', episode_reward, episode)
                    writer.add_scalar('Balance/episode', balance, episode)
                    writer.add_scalar('PnL/episode', pnl, episode)
                    writer.add_scalar('NetPnL/episode', net_pnl, episode)
                    writer.add_scalar('Fees/episode', fees, episode)
                    writer.add_scalar('WinRate/episode', win_rate, episode)
                    writer.add_scalar('TradeCount/episode', trade_count, episode)
                    writer.add_scalar('Drawdown/episode', max_drawdown, episode)
                    writer.add_scalar('Loss/episode', avg_loss, episode)
                    writer.add_scalar('Epsilon/episode', agent.epsilon, episode)

                stats['episode_rewards'].append(episode_reward)
                stats['episode_lengths'].append(steps_taken)
                stats['balances'].append(balance)
                stats['win_rates'].append(win_rate)
                stats['episode_pnls'].append(pnl)
                stats['drawdowns'].append(max_drawdown)
                stats['trade_counts'].append(trade_count)
                stats['loss_values'].append(avg_loss)
                stats['fees'].append(fees)
                stats['net_pnl_after_fees'].append(net_pnl)

                cumulative_pnl_total += pnl
                cumulative_net_pnl_total += net_pnl
                stats['cumulative_pnl'].append(cumulative_pnl_total)
                if writer:
                    writer.add_scalar('CumulativePnL/episode', cumulative_pnl_total, episode)
                    writer.add_scalar('CumulativeNetPnL/episode', cumulative_net_pnl_total, episode)

                # Save best models
                if episode_reward > best_reward:
                    best_reward = episode_reward
                    try:
                        if _save_agent_model(agent, 'models/trading_agent_best_reward.pt', use_compact_save):
                            logger.info(f"New best reward: {best_reward:.2f}")
                    except Exception as e:
                        logger.error(f"Error saving best reward model: {e}")

                if pnl > best_pnl:
                    best_pnl = pnl
                    try:
                        if _save_agent_model(agent, 'models/trading_agent_best_pnl.pt', use_compact_save):
                            logger.info(f"New best PnL: ${best_pnl:.2f}")
                    except Exception as e:
                        logger.error(f"Error saving best PnL model: {e}")

                if net_pnl > best_net_pnl:
                    best_net_pnl = net_pnl
                    try:
                        if _save_agent_model(agent, 'models/trading_agent_best_net_pnl.pt', use_compact_save):
                            logger.info(f"New best Net PnL: ${best_net_pnl:.2f}")
                    except Exception as e:
                        logger.error(f"Error saving best net PnL model: {e}")

                if episode % 10 == 0:
                    try:
                        _save_agent_model(agent, f'models/trading_agent_checkpoint_{episode}.pt', use_compact_save)
                    except Exception as e:
                        logger.error(f"Error saving checkpoint model: {e}")

                agent.update_epsilon(episode)

                logger.info(f"Episode {episode+1}/{num_episodes} | "
                            f"Reward: {episode_reward:.2f} | "
                            f"Balance: ${balance:.2f} | "
                            f"PnL: ${pnl:.2f} | "
                            f"Fees: ${fees:.2f} | "
                            f"Net PnL: ${net_pnl:.2f} | "
                            f"Win Rate: {win_rate:.2f} | "
                            f"Trades: {trade_count} | "
                            f"Loss: {avg_loss:.5f} | "
                            f"Epsilon: {agent.epsilon:.4f}")

            except Exception as e:
                logger.error(f"Error in episode {episode}: {e}")
                logger.error(traceback.format_exc())
                continue

        clean_memory()

        try:
            _save_agent_model(agent, 'models/trading_agent_final.pt', use_compact_save)
        except Exception as e:
            logger.error(f"Error saving final model: {e}")

        try:
            # Pad every list with NaN so the DataFrame columns have equal length
            max_length = max(len(v) for v in stats.values() if isinstance(v, list))
            for k, v in stats.items():
                if isinstance(v, list) and len(v) < max_length:
                    stats[k] = v + [float('nan')] * (max_length - len(v))

            stats_df = pd.DataFrame(stats)
            stats_df.to_csv('training_stats.csv', index=False)
            logger.info("Training statistics saved to training_stats.csv")
        except Exception as e:
            logger.error(f"Failed to save training statistics: {e}")
            logger.error(traceback.format_exc())
    finally:
        # Release the exchange even if training was cancelled or interrupted
        if exchange:
            try:
                # Only ccxt.async_support exchanges expose an awaitable close()
                if hasattr(exchange, 'close'):
                    await exchange.close()
                    logger.info("Closed exchange connection")
                else:
                    logger.info("Exchange doesn't have close method (standard ccxt), skipping close")
            except Exception as e:
                logger.error(f"Error closing exchange: {e}")

    return stats
|
|
|
|
def plot_training_results(stats):
    """Plot per-episode training metrics and save them to disk.

    Writes a 3x2 grid of line charts to ``training_results.png`` and the
    length-aligned metrics (plus a 1-based ``episode`` column) to
    ``training_stats.csv``. Any error is logged, never raised.

    Args:
        stats: Mapping of metric name -> per-episode sequence. The keys
            ``episode_rewards``, ``balances``, ``win_rates``, ``episode_pnls``,
            ``cumulative_pnl`` and ``drawdowns`` are plotted when present;
            other sequence-valued keys are only written to the CSV. Values
            that are not sequences (scalars, strings, dicts) are ignored.
    """
    # (subplot position, stats key, title, y-axis label, optional y-limits)
    panels = (
        ((0, 0), 'episode_rewards', 'Episode Rewards', 'Reward', None),
        ((0, 1), 'balances', 'Account Balance', 'Balance ($)', None),
        ((1, 0), 'win_rates', 'Win Rate', 'Win Rate', [0, 1]),
        ((1, 1), 'episode_pnls', 'Episode PnL', 'PnL ($)', None),
        ((2, 0), 'cumulative_pnl', 'Cumulative PnL', 'Cumulative PnL ($)', None),
        ((2, 1), 'drawdowns', 'Maximum Drawdown', 'Drawdown', None),
    )

    try:
        if not stats or len(stats.get('episode_rewards', [])) == 0:
            logger.warning("No training data to plot")
            return

        # Keep only non-empty sequence-valued entries, converted to plain lists:
        # scalars would break len(), and numpy arrays would turn the list
        # concatenation used for padding into elementwise addition.
        series = {}
        for key, values in stats.items():
            if isinstance(values, (str, bytes, dict)) or not hasattr(values, '__len__'):
                continue
            values = list(values)
            if values:
                series[key] = values

        if not series:
            logger.warning("No training data to plot")
            return

        # Metrics recorded on fewer episodes are padded with their last value
        # so every DataFrame column spans the same number of episodes.
        max_len = max(len(values) for values in series.values())
        df = pd.DataFrame({
            key: values + [values[-1]] * (max_len - len(values))
            for key, values in series.items()
        })
        df['episode'] = range(1, len(df) + 1)

        fig, axes = plt.subplots(3, 2, figsize=(15, 15))
        try:
            for (row, col), key, title, ylabel, ylim in panels:
                if key not in df.columns:
                    continue
                ax = axes[row, col]
                ax.plot(df['episode'], df[key])
                ax.set_title(title)
                ax.set_xlabel('Episode')
                ax.set_ylabel(ylabel)
                if ylim is not None:
                    ax.set_ylim(ylim)
                ax.grid(True)

            fig.tight_layout()
            fig.savefig('training_results.png')
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
        logger.info("Training results saved to training_results.png")

        df.to_csv('training_stats.csv', index=False)
        logger.info("Training statistics saved to training_stats.csv")

    except Exception as e:
        logger.error(f"Error plotting training results: {e}")
        logger.error(traceback.format_exc())
|
|
|
|
def evaluate_agent(agent, env, num_episodes=10, max_steps_per_episode=None):
    """Evaluate the agent greedily (no exploration) on ``env`` and log the results.

    Args:
        agent: Object exposing ``select_action(state, training=False)``.
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(next_state, reward, done, info)``, and ``balance`` / ``trades``
            attributes.
        num_episodes: Number of evaluation episodes. Values <= 0 skip the
            evaluation and return zeros instead of dividing by zero.
        max_steps_per_episode: Optional cap on steps per episode. ``None``
            (default) runs each episode until the environment reports ``done``.

    Returns:
        ``(avg_reward, avg_profit, win_rate)``: the mean episode reward, the mean
        change in ``env.balance`` per episode, and the percentage (0-100) of
        trades with ``pnl_percent > 0`` among trades that carry that key.
    """
    if num_episodes <= 0:
        logger.warning(f"evaluate_agent called with num_episodes={num_episodes}; nothing to evaluate")
        return 0.0, 0.0, 0

    total_reward = 0
    total_profit = 0
    total_trades = 0
    winning_trades = 0

    for _ in range(num_episodes):
        state = env.reset()
        initial_balance = env.balance
        episode_reward = 0
        steps = 0
        done = False

        while not done:
            if max_steps_per_episode is not None and steps >= max_steps_per_episode:
                break
            action = agent.select_action(state, training=False)
            state, reward, done, _info = env.step(action)
            episode_reward += reward
            steps += 1

        total_reward += episode_reward
        total_profit += env.balance - initial_balance

        # NOTE(review): env.trades is scanned after every episode; if reset()
        # does not clear it, earlier episodes' trades are counted again — confirm.
        for trade in env.trades:
            if 'pnl_percent' in trade:
                total_trades += 1
                if trade['pnl_percent'] > 0:
                    winning_trades += 1

    avg_reward = total_reward / num_episodes
    avg_profit = total_profit / num_episodes
    win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0

    logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
                f"Win Rate={win_rate:.1f}%")

    return avg_reward, avg_profit, win_rate
|
|
|
|
async def test_training():
    """Smoke-test the training loop with a short run on ETH/USDT 1m data.

    Builds a demo ``TradingEnvironment`` backed by a synchronous ``ccxt.mexc``
    client, fetches 1000 candles, runs 3 episodes of at most 100 steps (each
    step pushes the transition into ``agent.memory`` and calls
    ``agent.learn()``), then saves ``models/test_model.pt``.

    Returns:
        True when the run completes, False if any step raised. A failed model
        save is logged but does not fail the test.
    """
    logger.info("Starting training tests...")

    exchange = ccxt.mexc({
        'apiKey': MEXC_API_KEY,
        'secret': MEXC_SECRET_KEY,
        'enableRateLimit': True,
    })

    try:
        env = TradingEnvironment(
            exchange=exchange,
            symbol="ETH/USDT",
            timeframe="1m",
            leverage=MAX_LEVERAGE,
            initial_balance=100,  # Small balance for testing
            demo=True  # Always use demo mode for testing
        )

        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)

        # action_space may be a gym space object rather than an int; the agent
        # needs the number of actions (same resolution as live_trading uses).
        action_size = getattr(env.action_space, 'n', env.action_space)
        agent = Agent(state_size=STATE_SIZE, action_size=action_size)

        test_episodes = 3
        max_test_steps = 100  # keep the smoke test short
        logger.info(f"Running {test_episodes} test episodes...")

        for episode in range(test_episodes):
            state = env.reset()
            episode_reward = 0
            done = False
            step = 0

            while not done and step < max_test_steps:
                action = agent.select_action(state)
                next_state, reward, done, info = env.step(action)
                agent.memory.push(state, action, reward, next_state, done)
                agent.learn()

                state = next_state
                episode_reward += reward
                step += 1

                if step % 10 == 0:
                    logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")

            logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")

        try:
            agent.save("models/test_model.pt")
            logger.info("Successfully saved model")
        except Exception as e:
            logger.error(f"Error saving model: {e}")

        logger.info("Training tests completed successfully")
        return True

    except Exception as e:
        logger.error(f"Training test failed: {e}")
        logger.error(traceback.format_exc())
        return False

    finally:
        # The synchronous ccxt client may have no close() or return a plain
        # value; awaiting it unconditionally raised here and replaced the
        # function's return value with an exception.
        import inspect
        try:
            close = getattr(exchange, 'close', None)
            if close is not None:
                result = close()
                if inspect.isawaitable(result):
                    await result
        except Exception as e:
            logger.warning(f"Error closing exchange connection: {e}")
|
|
|
|
async def initialize_exchange():
    """Create the MEXC exchange client used by the bot.

    ``ccxt.pro.mexc`` (the async-capable client) is tried first; if that
    raises AttributeError or ImportError, the synchronous ``ccxt.mexc`` client
    is created instead. Any other failure is logged and re-raised.

    NOTE(review): depending on the installed ccxt version, ``ccxt.pro`` may
    only be reachable as an attribute after ``ccxt.pro`` has been imported
    explicitly, in which case the synchronous client is always returned —
    confirm which client callers actually receive.

    Returns:
        The exchange instance, configured with the MEXC credentials and with
        ccxt rate limiting enabled.
    """
    try:
        settings = {
            'apiKey': MEXC_API_KEY,
            'secret': MEXC_SECRET_KEY,
            'enableRateLimit': True
        }

        try:
            client = ccxt.pro.mexc(dict(settings))
            logger.info(f"Exchange initialized with async support: {client.id}")
        except (AttributeError, ImportError):
            # ccxt.pro unavailable: use the plain synchronous client.
            client = ccxt.mexc(dict(settings))
            logger.info(f"Exchange initialized with standard CCXT: {client.id}")

        return client

    except Exception as err:
        logger.error(f"Failed to initialize exchange: {err}")
        raise
|
|
|
|
async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
    """Return historical OHLCV candles for ``symbol`` via ``fetch_ohlcv_data``.

    An empty result is logged as a warning and returned as-is; any exception
    is logged and converted into an empty list.
    """
    try:
        logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")
        candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)

        if candles:
            return candles

        logger.warning("No historical data received")
        return candles

    except Exception as err:
        logger.error(f"Failed to fetch historical data: {err}")
        return []
|
|
|
|
async def live_trading(
    symbol="ETH/USDT",
    timeframe="1m",
    model_path=None,
    demo=False,
    leverage=50,
    initial_balance=1000,
    max_position_size=0.1,
    commission=0.0004,
    window_size=30,
    update_interval=60,
    stop_loss_pct=0.02,
    take_profit_pct=0.04,
    max_trades_per_day=10,
    risk_per_trade=0.02,
    use_trailing_stop=False,
    trailing_stop_callback=0.005,
    use_dynamic_sizing=True,
    use_volatility_sizing=True,
    use_multi_timeframe=True,
    use_sentiment=False,
    use_limit_orders=False,
    use_dollar_cost_avg=False,
    use_grid_trading=False,
    use_martingale=False,
    use_anti_martingale=False,
    use_custom_indicators=True,
    use_ml_predictions=True,
    use_ensemble=True,
    use_reinforcement=True,
    use_risk_management=True,
    use_portfolio_management=False,
    use_position_sizing=True,
    use_stop_loss=True,
    use_take_profit=True,
    use_trailing_stop_loss=False,
    use_dynamic_stop_loss=True,
    use_dynamic_take_profit=True,
    use_dynamic_trailing_stop=False,
    use_dynamic_position_sizing=True,
    use_dynamic_leverage=False,
    use_dynamic_risk_per_trade=True,
    use_dynamic_max_trades_per_day=False,
    use_dynamic_update_interval=False,
    use_dynamic_window_size=False,
    use_dynamic_commission=False,
    use_dynamic_timeframe=False,
    use_dynamic_symbol=False,
    use_dynamic_model_path=False,
    use_dynamic_demo=False,
    use_dynamic_leverage_value=False,
    use_dynamic_initial_balance=False,
    use_dynamic_max_position_size=False,
    use_dynamic_stop_loss_pct=False,
    use_dynamic_take_profit_pct=False,
    use_dynamic_risk_per_trade_value=False,
    use_dynamic_trailing_stop_callback=False,
    use_dynamic_use_trailing_stop=False,
    use_dynamic_use_dynamic_sizing=False,
    use_dynamic_use_volatility_sizing=False,
    use_dynamic_use_multi_timeframe=False,
    use_dynamic_use_sentiment=False,
    use_dynamic_use_limit_orders=False,
    use_dynamic_use_dollar_cost_avg=False,
    use_dynamic_use_grid_trading=False,
    use_dynamic_use_martingale=False,
    use_dynamic_use_anti_martingale=False,
    use_dynamic_use_custom_indicators=False,
    use_dynamic_use_ml_predictions=False,
    use_dynamic_use_ensemble=False,
    use_dynamic_use_reinforcement=False,
    use_dynamic_use_risk_management=False,
    use_dynamic_use_portfolio_management=False,
    use_dynamic_use_position_sizing=False,
    use_dynamic_use_stop_loss=False,
    use_dynamic_use_take_profit=False,
    use_dynamic_use_trailing_stop_loss=False,
    use_dynamic_use_dynamic_stop_loss=False,
    use_dynamic_use_dynamic_take_profit=False,
    use_dynamic_use_dynamic_trailing_stop=False,
    use_dynamic_use_dynamic_position_sizing=False,
    use_dynamic_use_dynamic_leverage=False,
    use_dynamic_use_dynamic_risk_per_trade=False,
    use_dynamic_use_dynamic_max_trades_per_day=False,
    use_dynamic_use_dynamic_update_interval=False,
    use_dynamic_use_dynamic_window_size=False,
    use_dynamic_use_dynamic_commission=False,
    use_dynamic_use_dynamic_timeframe=False,
    use_dynamic_use_dynamic_symbol=False,
    use_dynamic_use_dynamic_model_path=False,
    use_dynamic_use_dynamic_demo=False,
    use_dynamic_use_dynamic_leverage_value=False,
    use_dynamic_use_dynamic_initial_balance=False,
    use_dynamic_use_dynamic_max_position_size=False,
    use_dynamic_use_dynamic_stop_loss_pct=False,
    use_dynamic_use_dynamic_take_profit_pct=False,
    use_dynamic_use_dynamic_risk_per_trade_value=False,
    use_dynamic_use_dynamic_trailing_stop_callback=False,
    agent=None,
    env=None,
    exchange=None,
):
    """
    Run the agent against live market data until the environment reports ``done``.

    Unless supplied by the caller, this creates the exchange client, the
    ``TradingEnvironment`` and the ``Agent``. On every iteration it:
    selects a greedy action (``training=False``) and steps the environment;
    appends executed trades to ``live_trading_<timestamp>.csv``; writes
    TensorBoard scalars every 10 steps; refreshes market data every
    ``update_interval`` seconds. Errors are logged, not raised.
    TensorBoard writer and (self-created) exchange are always closed on exit.

    Args:
        symbol: Trading pair symbol.
        timeframe: Candle timeframe.
        model_path: Optional checkpoint loaded into the agent via ``agent.load``.
        demo: Request exchange sandbox mode. If the exchange rejects it, the run
            continues in mock mode. In demo/mock mode the account leverage is
            never changed.
        leverage: Leverage set on the exchange (live mode only) and passed to
            the environment.
        update_interval: Seconds between market-data refreshes.
        initial_balance, max_position_size, commission, window_size,
        stop_loss_pct, take_profit_pct, max_trades_per_day, risk_per_trade,
        trailing_stop_callback: Forwarded to ``TradingEnvironment``.
        use_*: Feature flags forwarded unchanged to ``TradingEnvironment``;
            this function does not read them itself.
        agent: Optional pre-built agent (e.g. one the caller already loaded).
        env: Optional pre-built environment; a new one is created otherwise.
        exchange: Optional exchange client. A caller-supplied client is NOT
            closed here — its owner is responsible for closing it.
    """
    import inspect

    async def _resolve(result):
        # ccxt async clients return coroutines; the synchronous client returns
        # plain values, which must not be awaited.
        if inspect.isawaitable(result):
            return await result
        return result

    logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
    logger.info(f"Demo mode: {demo}, Leverage: {leverage}x")

    using_mock_trading = False
    owns_exchange = exchange is None
    writer = None

    try:
        if exchange is None:
            exchange = await initialize_exchange()

        if demo:
            try:
                exchange.set_sandbox_mode(demo)
                logger.info(f"Sandbox mode set to {demo}")
            except Exception as e:
                logger.warning(f"Exchange doesn't support sandbox mode: {e}")
                logger.info("Continuing in mock trading mode instead")
                using_mock_trading = True

        # Leverage is an account-level setting. In mock mode the client was not
        # switched to a sandbox, so it still points at the real account:
        # only change leverage when trading live.
        if not demo:
            try:
                await _resolve(exchange.set_leverage(leverage, symbol))
                logger.info(f"Leverage set to {leverage}x")
            except Exception as e:
                logger.warning(f"Failed to set leverage: {e}")
        elif using_mock_trading:
            logger.info("Mock trading mode: leaving exchange leverage unchanged")

        if env is None:
            env = TradingEnvironment(
                initial_balance=initial_balance,
                leverage=leverage,
                window_size=window_size,
                commission=commission,
                symbol=symbol,
                timeframe=timeframe,
                max_position_size=max_position_size,
                stop_loss_pct=stop_loss_pct,
                take_profit_pct=take_profit_pct,
                max_trades_per_day=max_trades_per_day,
                risk_per_trade=risk_per_trade,
                use_trailing_stop=use_trailing_stop,
                trailing_stop_callback=trailing_stop_callback,
                use_dynamic_sizing=use_dynamic_sizing,
                use_volatility_sizing=use_volatility_sizing,
                use_multi_timeframe=use_multi_timeframe,
                use_sentiment=use_sentiment,
                use_limit_orders=use_limit_orders,
                use_dollar_cost_avg=use_dollar_cost_avg,
                use_grid_trading=use_grid_trading,
                use_martingale=use_martingale,
                use_anti_martingale=use_anti_martingale,
                use_custom_indicators=use_custom_indicators,
                use_ml_predictions=use_ml_predictions,
                use_ensemble=use_ensemble,
                use_reinforcement=use_reinforcement,
                use_risk_management=use_risk_management,
                use_portfolio_management=use_portfolio_management,
                use_position_sizing=use_position_sizing,
                use_stop_loss=use_stop_loss,
                use_take_profit=use_take_profit,
                use_trailing_stop_loss=use_trailing_stop_loss,
                use_dynamic_stop_loss=use_dynamic_stop_loss,
                use_dynamic_take_profit=use_dynamic_take_profit,
                use_dynamic_trailing_stop=use_dynamic_trailing_stop,
                use_dynamic_position_sizing=use_dynamic_position_sizing,
                use_dynamic_leverage=use_dynamic_leverage,
                use_dynamic_risk_per_trade=use_dynamic_risk_per_trade,
                use_dynamic_max_trades_per_day=use_dynamic_max_trades_per_day,
                use_dynamic_update_interval=use_dynamic_update_interval,
                use_dynamic_window_size=use_dynamic_window_size,
                use_dynamic_commission=use_dynamic_commission,
                use_dynamic_timeframe=use_dynamic_timeframe,
                use_dynamic_symbol=use_dynamic_symbol,
                use_dynamic_model_path=use_dynamic_model_path,
                use_dynamic_demo=use_dynamic_demo,
                use_dynamic_leverage_value=use_dynamic_leverage_value,
                use_dynamic_initial_balance=use_dynamic_initial_balance,
                use_dynamic_max_position_size=use_dynamic_max_position_size,
                use_dynamic_stop_loss_pct=use_dynamic_stop_loss_pct,
                use_dynamic_take_profit_pct=use_dynamic_take_profit_pct,
                use_dynamic_risk_per_trade_value=use_dynamic_risk_per_trade_value,
                use_dynamic_trailing_stop_callback=use_dynamic_trailing_stop_callback,
                use_dynamic_use_trailing_stop=use_dynamic_use_trailing_stop,
                use_dynamic_use_dynamic_sizing=use_dynamic_use_dynamic_sizing,
                use_dynamic_use_volatility_sizing=use_dynamic_use_volatility_sizing,
                use_dynamic_use_multi_timeframe=use_dynamic_use_multi_timeframe,
                use_dynamic_use_sentiment=use_dynamic_use_sentiment,
                use_dynamic_use_limit_orders=use_dynamic_use_limit_orders,
                use_dynamic_use_dollar_cost_avg=use_dynamic_use_dollar_cost_avg,
                use_dynamic_use_grid_trading=use_dynamic_use_grid_trading,
                use_dynamic_use_martingale=use_dynamic_use_martingale,
                use_dynamic_use_anti_martingale=use_dynamic_use_anti_martingale,
                use_dynamic_use_custom_indicators=use_dynamic_use_custom_indicators,
                use_dynamic_use_ml_predictions=use_dynamic_use_ml_predictions,
                use_dynamic_use_ensemble=use_dynamic_use_ensemble,
                use_dynamic_use_reinforcement=use_dynamic_use_reinforcement,
                use_dynamic_use_risk_management=use_dynamic_use_risk_management,
                use_dynamic_use_portfolio_management=use_dynamic_use_portfolio_management,
                use_dynamic_use_position_sizing=use_dynamic_use_position_sizing,
                use_dynamic_use_stop_loss=use_dynamic_use_stop_loss,
                use_dynamic_use_take_profit=use_dynamic_use_take_profit,
                use_dynamic_use_trailing_stop_loss=use_dynamic_use_trailing_stop_loss,
                use_dynamic_use_dynamic_stop_loss=use_dynamic_use_dynamic_stop_loss,
                use_dynamic_use_dynamic_take_profit=use_dynamic_use_dynamic_take_profit,
                use_dynamic_use_dynamic_trailing_stop=use_dynamic_use_dynamic_trailing_stop,
                use_dynamic_use_dynamic_position_sizing=use_dynamic_use_dynamic_position_sizing,
                use_dynamic_use_dynamic_leverage=use_dynamic_use_dynamic_leverage,
                use_dynamic_use_dynamic_risk_per_trade=use_dynamic_use_dynamic_risk_per_trade,
                use_dynamic_use_dynamic_max_trades_per_day=use_dynamic_use_dynamic_max_trades_per_day,
                use_dynamic_use_dynamic_update_interval=use_dynamic_use_dynamic_update_interval,
                use_dynamic_use_dynamic_window_size=use_dynamic_use_dynamic_window_size,
                use_dynamic_use_dynamic_commission=use_dynamic_use_dynamic_commission,
                use_dynamic_use_dynamic_timeframe=use_dynamic_use_dynamic_timeframe,
                use_dynamic_use_dynamic_symbol=use_dynamic_use_dynamic_symbol,
                use_dynamic_use_dynamic_model_path=use_dynamic_use_dynamic_model_path,
                use_dynamic_use_dynamic_demo=use_dynamic_use_dynamic_demo,
                use_dynamic_use_dynamic_leverage_value=use_dynamic_use_dynamic_leverage_value,
                use_dynamic_use_dynamic_initial_balance=use_dynamic_use_dynamic_initial_balance,
                use_dynamic_use_dynamic_max_position_size=use_dynamic_use_dynamic_max_position_size,
                use_dynamic_use_dynamic_stop_loss_pct=use_dynamic_use_dynamic_stop_loss_pct,
                use_dynamic_use_dynamic_take_profit_pct=use_dynamic_use_dynamic_take_profit_pct,
                use_dynamic_use_dynamic_risk_per_trade_value=use_dynamic_use_dynamic_risk_per_trade_value,
                use_dynamic_use_dynamic_trailing_stop_callback=use_dynamic_use_dynamic_trailing_stop_callback,
            )

        logger.info(f"Fetching initial data for {symbol}")
        await fetch_and_update_data(exchange, env, symbol, timeframe)

        if agent is None:
            state_size = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
            action_size = env.action_space.n if hasattr(env.action_space, 'n') else 4
            agent = Agent(state_size=state_size, action_size=action_size, hidden_size=384)

        if model_path:
            agent.load(model_path)
            logger.info(f"Model loaded successfully from {model_path}")

        # One timestamp so the TensorBoard run and the CSV log share an id.
        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        writer = SummaryWriter(log_dir=f"runs/live_{run_id}")
        agent.writer = writer

        trades = []
        win_count = 0
        loss_count = 0

        csv_columns = ("timestamp", "action", "price", "position_size", "balance", "pnl")
        log_file = f"live_trading_{run_id}.csv"
        with open(log_file, 'w', newline='') as f:
            csv.writer(f).writerow(csv_columns)

        logger.info(f"Starting live trading with {symbol} on {timeframe} timeframe")

        step_counter = 0
        last_update_time = time.time()

        while True:
            state = env.get_state()
            action = agent.select_action(state, training=False)
            _next_state, reward, done, info = env.step(action)

            if info.get('trade_executed', False):
                trade_pnl = env.last_trade_profit
                trade_data = {
                    'timestamp': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                    'action': info['action'],
                    'price': env.current_price,
                    'position_size': env.position_size,
                    'balance': env.balance,
                    'pnl': trade_pnl,
                }
                trades.append(trade_data)

                # NOTE(review): every executed trade is scored, so trades that
                # report zero PnL (e.g. position opens, if the env reports them
                # that way) are counted as losses — confirm intended.
                if trade_pnl > 0:
                    win_count += 1
                else:
                    loss_count += 1

                # csv.writer quotes fields that contain commas or quotes.
                with open(log_file, 'a', newline='') as f:
                    csv.writer(f).writerow([trade_data[column] for column in csv_columns])

                logger.info(f"Trade executed: {info['action']} at ${trade_data['price']:.2f}, PnL: ${trade_pnl:.2f}")

            if step_counter % 10 == 0:
                writer.add_scalar('Live/Balance', env.balance, step_counter)
                writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                writer.add_scalar('Live/Reward', reward, step_counter)

            current_time = time.time()
            if current_time - last_update_time > update_interval:
                await fetch_and_update_data(exchange, env, symbol, timeframe)
                last_update_time = current_time

                closed_trades = win_count + loss_count
                win_rate = win_count / closed_trades if closed_trades > 0 else 0
                logger.info(
                    f"\nStep: {step_counter}\n"
                    f"Balance: ${env.balance:.2f}\n"
                    f"Total PnL: ${env.total_pnl:.2f}\n"
                    f"Win Rate: {win_rate:.2f}\n"
                    f"Trades: {len(trades)}\n"
                )

            step_counter += 1

            # Throttle the loop to avoid excessive API calls.
            await asyncio.sleep(1)

            if done:
                break

        closed_trades = win_count + loss_count
        win_rate = win_count / closed_trades if closed_trades > 0 else 0
        logger.info(
            f"\nLive Trading Summary:\n"
            f"Total Steps: {step_counter}\n"
            f"Final Balance: ${env.balance:.2f}\n"
            f"Total PnL: ${env.total_pnl:.2f}\n"
            f"Win Rate: {win_rate:.2f}\n"
            f"Total Trades: {len(trades)}\n"
        )

    except Exception as e:
        logger.error(f"Error in live trading: {e}")
        logger.error(traceback.format_exc())

    finally:
        # Runs on success, error and task cancellation alike.
        if writer is not None:
            writer.close()

        if owns_exchange and exchange is not None:
            try:
                close = getattr(exchange, 'close', None)
                if close is not None:
                    await _resolve(close())
                    logger.info("Exchange connection closed")
            except Exception as e:
                logger.warning(f"Error closing exchange connection: {e}")
|
|
|
|
async def get_latest_candle(exchange, symbol, timeframe="1m"):
    """
    Get the most recent candle for a symbol.

    Args:
        exchange: Exchange instance
        symbol: Trading pair symbol
        timeframe: Candle timeframe (default "1m", the previous fixed value)

    Returns:
        The newest ``[timestamp, open, high, low, close, volume]`` candle, or
        None if nothing was received or the fetch failed.
    """
    try:
        data = await fetch_ohlcv_data(exchange, symbol, timeframe, 1)

        if data:
            # ccxt returns candles oldest-first; take the last one in case the
            # exchange ignored ``limit`` and returned more than requested.
            return data[-1]

        logger.warning("No candle data received")
        return None
    except Exception as e:
        logger.error(f"Failed to fetch latest candle: {e}")
        return None
|
|
|
|
async def fetch_ohlcv_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000,
                           max_retries=3, retry_delay=5):
    """
    Fetch OHLCV candles from an exchange, retrying on errors or empty responses.

    Supports both async clients (whose ``fetch_ohlcv`` is a coroutine function,
    e.g. ccxt.async_support / ccxt.pro) and synchronous clients, which are run
    in the default thread-pool executor so the event loop is not blocked.

    Args:
        exchange: Client exposing ``fetch_ohlcv(symbol, timeframe, limit=...)``.
        symbol: Trading pair symbol.
        timeframe: Candle timeframe.
        limit: Number of candles to request.
        max_retries: Total number of attempts.
        retry_delay: Seconds to wait between attempts (not after the last one).

    Returns:
        List of ``[timestamp, open, high, low, close, volume]`` lists, or an
        empty list if the client lacks ``fetch_ohlcv`` or every attempt failed.
    """
    import inspect

    if not hasattr(exchange, 'fetch_ohlcv'):
        logging.error("Exchange does not support OHLCV data fetching")
        return []

    fetch = exchange.fetch_ohlcv
    loop = asyncio.get_running_loop()

    for attempt in range(1, max_retries + 1):
        logging.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt {attempt}/{max_retries})")
        try:
            if inspect.iscoroutinefunction(fetch):
                ohlcv = await fetch(symbol, timeframe, limit=limit)
            else:
                ohlcv = await loop.run_in_executor(
                    None,
                    lambda: fetch(symbol, timeframe, limit=limit)
                )
                # A wrapped async method may not look like a coroutine function
                # but still hand back an awaitable.
                if inspect.isawaitable(ohlcv):
                    ohlcv = await ohlcv

            if not ohlcv:
                logging.warning(f"No data returned from exchange (attempt {attempt}/{max_retries})")
            else:
                # Each candle must have exactly six fields; anything else raises
                # and is treated as a failed attempt.
                data = [
                    [ts, open_price, high, low, close, volume]
                    for ts, open_price, high, low, close, volume in ohlcv
                ]
                logging.info(f"Successfully fetched {len(data)} candles")
                return data

        except Exception as e:
            logging.error(f"Error fetching OHLCV data (attempt {attempt}/{max_retries}): {e}")

        if attempt < max_retries:
            await asyncio.sleep(retry_delay)

    logging.error(f"Failed to fetch OHLCV data after {max_retries} attempts")
    return []
|
|
|
|
# PyTorch 2.6+ loads checkpoints with weights_only=True by default, which
# rejects pickled globals that have not been explicitly allow-listed.
def ensure_pytorch_compatibility():
    """Allow-list NumPy's scalar reconstructor for ``torch.load`` (PyTorch 2.6+).

    Checkpoints that pickled NumPy scalars reference
    ``numpy.core.multiarray.scalar`` (NumPy 1.x) or
    ``numpy._core.multiarray.scalar`` (NumPy 2.x). ``add_safe_globals``
    expects the actual callable objects, not dotted-name strings, so the
    function object is imported from whichever module this NumPy provides.

    NOTE(review): depending on what a checkpoint contains, related globals
    such as ``numpy.dtype`` may also need allow-listing — confirm against a
    failing load's error message.

    Failures to configure (older PyTorch without ``add_safe_globals``, or an
    unexpected NumPy layout) are logged as warnings and otherwise ignored.
    """
    try:
        from torch.serialization import add_safe_globals

        try:
            from numpy._core.multiarray import scalar as numpy_scalar  # NumPy >= 2.0
        except ImportError:
            from numpy.core.multiarray import scalar as numpy_scalar  # NumPy 1.x

        add_safe_globals([numpy_scalar])
        logger.info("Added numpy scalar to PyTorch safe globals")
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure PyTorch compatibility: {e}")
        logger.warning("This might cause issues with model loading in PyTorch 2.6+")
|
|
|
|
# Command-line entry point. PyTorch checkpoint compatibility is configured
# first so that model loading in eval/live mode works.
async def main():
    """Parse command-line arguments and run training, evaluation or live trading.

    Modes (``--mode``):
        train: fetch 1000 initial candles, build a fresh agent and run
            ``train_agent`` for ``--episodes`` episodes.
        eval: load ``--model`` (default ``models/trading_agent_best_pnl.pt``)
            and run ``evaluate_agent`` for ``--episodes`` episodes.
        live: load the model the same way and hand the agent, environment and
            exchange to ``live_trading``.

    Errors are logged rather than raised. The exchange connection created
    here is closed on exit.
    """
    import inspect

    ensure_pytorch_compatibility()

    parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
                        help='Operation mode: train, eval, or live')
    parser.add_argument('--episodes', type=int, default=1000,
                        help='Number of episodes for training or evaluation')
    parser.add_argument('--max_steps', type=int, default=1000,
                        help='Maximum steps per episode for training')
    parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
                        help='Run in demo mode (paper trading) if true')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m',
                        help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
    parser.add_argument('--leverage', type=int, default=50,
                        help='Leverage for futures trading')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model file for evaluation or live trading')
    parser.add_argument('--compact_save', action='store_true',
                        help='Use compact model saving (for low disk space)')

    args = parser.parse_args()

    demo_mode = args.demo.lower() == 'true'
    device = get_device()
    exchange = None

    try:
        exchange = await initialize_exchange()

        env = TradingEnvironment(
            initial_balance=INITIAL_BALANCE,
            window_size=30,
            leverage=args.leverage,
            exchange_id='mexc',
            symbol=args.symbol,
            timeframe=args.timeframe
        )

        if args.mode == 'train':
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            # STATE_SIZE and action_size=4 match the agents built in eval/live mode.
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            logger.info(f"Starting training for {args.episodes} episodes...")
            await train_agent(agent, env, num_episodes=args.episodes,
                              max_steps_per_episode=args.max_steps,
                              use_compact_save=args.compact_save)

        elif args.mode in ('eval', 'live'):
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

            model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return

            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            # NumPy's scalar reconstructor was already allow-listed for
            # torch.load by ensure_pytorch_compatibility() above.
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.error(f"Failed to load model: {e}")

                if args.mode == 'live':
                    # Trading real money with an untrained model needs explicit consent.
                    confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
                    if confirmation.lower() != 'y':
                        logger.info("Live trading canceled by user")
                        return
                    logger.info("Continuing with a new model")
                else:
                    logger.info("Continuing evaluation with a new model")

            if args.mode == 'eval':
                logger.info("Evaluating agent...")
                evaluate_agent(agent, env, num_episodes=args.episodes)
            else:
                logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
                logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")

                await live_trading(
                    agent=agent,
                    env=env,
                    exchange=exchange,
                    symbol=args.symbol,
                    timeframe=args.timeframe,
                    demo=demo_mode,
                    leverage=args.leverage
                )

    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
    finally:
        if exchange is not None:
            try:
                closer = getattr(exchange, 'close', None)
                if closer is None and hasattr(exchange, 'client'):
                    closer = getattr(exchange.client, 'close', None)
                if closer is not None:
                    result = closer()
                    # ccxt async clients return a coroutine; the sync client
                    # does not, and awaiting its return value would raise.
                    if inspect.isawaitable(result):
                        await result
                    logger.info("Exchange connection closed")
            except Exception as e:
                logger.warning(f"Could not properly close exchange connection: {e}")
|
|
|
|
# Chart helper: renders OHLCV data (and optional trade markers) as a matplotlib figure.
|
|
def _candle_body_width(index, fraction=0.6):
    """Return a bar width equal to ``fraction`` of the typical spacing between candles.

    On a datetime axis matplotlib measures bar widths in days, so a fixed width
    of 0.6 would draw every 1-minute candle about 14 hours wide. For
    non-datetime indexes, or when no positive spacing can be found, the plain
    ``fraction`` is returned (the previous fixed width).
    """
    if not isinstance(index, pd.DatetimeIndex) or len(index) < 2:
        return fraction
    steps = index.to_series().diff().dropna()
    if steps.empty:
        return fraction
    step = steps.median()
    if pd.isna(step) or step <= pd.Timedelta(0):
        return fraction
    return fraction * (step / pd.Timedelta(days=1))


def create_candlestick_figure(data, trades=None, title="Trading Chart"):
    """Create a candlestick chart (price panel + volume panel) with trades marked.

    Args:
        data: OHLCV rows as a pandas DataFrame, or anything ``pd.DataFrame()``
            accepts. It must provide ``timestamp``, ``open``, ``high``, ``low``,
            ``close`` and ``volume`` columns. Numeric timestamps (Python or
            numpy scalars) are converted from epoch milliseconds.
        trades: Optional iterable of dicts with ``timestamp``, ``type`` and
            ``price`` keys. A ``type`` of 'buy' (case-insensitive) is drawn as
            a green up-triangle and anything else as a red down-triangle.
            Entries missing a key are skipped.
        title: Title of the price panel.

    Returns:
        The matplotlib Figure, or None when there are fewer than 5 rows, a
        required column is missing, or plotting fails (the problem is logged).
        The figure is created through pyplot, so callers should
        ``plt.close(fig)`` once they are done with it.
    """
    fig = None
    try:
        if data is None or len(data) < 5:
            logger.warning("Not enough data for candlestick chart")
            return None

        df = data.copy() if isinstance(data, pd.DataFrame) else pd.DataFrame(data)

        for col in ('timestamp', 'open', 'high', 'low', 'close', 'volume'):
            if col not in df.columns:
                logger.warning(f"Missing required column {col} for candlestick chart")
                return None

        # numpy scalars (e.g. np.int64 taken from an int64 column) are not
        # Python ints, so they must be listed explicitly to detect ms timestamps.
        numeric_types = (int, float, np.integer, np.floating)
        if isinstance(df['timestamp'].iloc[0], numeric_types):
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        if df.index.name != 'timestamp':
            df = df.set_index('timestamp')

        x = df.index
        opens = df['open'].to_numpy(dtype=float)
        highs = df['high'].to_numpy(dtype=float)
        lows = df['low'].to_numpy(dtype=float)
        closes = df['close'].to_numpy(dtype=float)
        volumes = df['volume'].to_numpy(dtype=float)
        colors = np.where(closes >= opens, 'green', 'red').tolist()
        width = _candle_body_width(x)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True,
                                       gridspec_kw={'height_ratios': [3, 1]})

        # Candle bodies and wicks: one vectorized call each instead of two
        # artists per candle.
        ax1.bar(x, np.abs(closes - opens), bottom=np.minimum(opens, closes),
                width=width, color=colors, alpha=0.6)
        ax1.vlines(x, lows, highs, colors=colors, linewidth=1)

        # Volume panel
        ax2.bar(x, volumes, width=width, color='blue', alpha=0.5)

        # Trade markers
        if trades:
            for trade in trades:
                if not all(key in trade for key in ('timestamp', 'type', 'price')):
                    continue
                trade_time = trade['timestamp']
                if isinstance(trade_time, numeric_types):
                    trade_time = pd.to_datetime(trade_time, unit='ms')
                is_buy = str(trade['type']).lower() == 'buy'
                ax1.scatter(trade_time, trade['price'],
                            color='green' if is_buy else 'red',
                            marker='^' if is_buy else 'v',
                            s=100, zorder=5)

        ax1.set_title(title)
        ax1.set_ylabel('Price')
        ax2.set_ylabel('Volume')
        ax1.grid(True)
        ax2.grid(True)

        # The panels share the x axis, so only the volume panel shows dates.
        plt.setp(ax1.get_xticklabels(), visible=False)

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating candlestick figure: {e}")
        if fig is not None:
            # Drop the half-built figure from pyplot's registry so it does not leak.
            plt.close(fig)
        return None
|
|
|
|
class CandlePatternCNN(nn.Module):
    """Convolutional neural network for detecting candlestick patterns.

    One shared convolutional trunk encodes each timeframe's input
    (``input_channels`` x H x W) into a ``feature_dimension`` vector, and
    ``forward`` concatenates the 1m, 1h and 1d vectors along dim 1.
    """

    def __init__(self, input_channels=5, feature_dimension=512):
        super(CandlePatternCNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # fc1 expects a 128 x 4 x 4 map. A plain flatten only produced that for
        # 32x32 inputs (three 2x poolings: 32 -> 16 -> 8 -> 4). The adaptive pool
        # resizes any map to 4x4. It is an identity for 32x32 inputs and owns no
        # parameters, so existing state_dicts still load.
        self.adaptive_pool = nn.AdaptiveMaxPool2d((4, 4))

        # Projection layers
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(1024, feature_dimension)

        # Latest per-timeframe features, overwritten by every forward() call and
        # returned by get_features(). Plain tensor attributes rather than a dict.
        # NOTE(review): these keep the autograd graph of the last forward alive,
        # and copy.deepcopy() of the module fails while they are non-leaf tensors.
        # Consider storing .detach() copies if no caller needs gradients through
        # get_features().
        self.feature_1m = torch.zeros(1, feature_dimension)
        self.feature_1h = torch.zeros(1, feature_dimension)
        self.feature_1d = torch.zeros(1, feature_dimension)

    def forward(self, x_1m, x_1h, x_1d):
        """Encode the three timeframes and concatenate their features.

        Args:
            x_1m, x_1h, x_1d: (batch, input_channels, H, W) tensors, or unbatched
                (input_channels, H, W) tensors.

        Returns:
            Tensor of shape (batch, 3 * feature_dimension).
        """
        feat_1m = self.process_timeframe(x_1m)
        feat_1h = self.process_timeframe(x_1h)
        feat_1d = self.process_timeframe(x_1d)

        # Keep the per-timeframe features available through get_features().
        self.feature_1m = feat_1m
        self.feature_1h = feat_1h
        self.feature_1d = feat_1d

        return torch.cat([feat_1m, feat_1h, feat_1d], dim=1)

    def process_timeframe(self, x):
        """Run one timeframe batch through the shared trunk -> (batch, feature_dimension)."""
        if x.dim() == 3:  # single input: [channels, height, width]
            x = x.unsqueeze(0)

        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = self.adaptive_pool(x)

        # Flatten channel and spatial dimensions for the fully connected layers.
        x = torch.flatten(x, 1)

        x = self.relu4(self.fc1(x))
        return self.fc2(x)

    def get_features(self):
        """Return the (1m, 1h, 1d) features stored by the most recent forward() call."""
        return self.feature_1m, self.feature_1h, self.feature_1d
|
|
|
|
# Candle cache: keeps fetched candles per timeframe to reduce exchange API calls.
|
|
class CandleCache:
    """
    Cache system for candles of different timeframes.

    Reduces API calls by storing and updating candle data. Each candle is a
    row whose element 0 is its timestamp (presumably the
    ``[timestamp, open, high, low, close, volume]`` rows returned by
    ``fetch_ohlcv``; confirm against callers). '1m', '1h' and '1d' exist up
    front, and other timeframes are created on first use.
    """

    # Size of the rolling buffer of real-time ticks.
    MAX_TICKS = 1000

    def __init__(self):
        self.candles = {'1m': [], '1h': [], '1d': []}
        self.last_updated = {'1m': None, '1h': None, '1d': None}
        # Ticks channel for real-time data (WebSocket)
        self.ticks = []
        self.last_tick_time = None

    def add_candles(self, timeframe, new_candles):
        """Merge ``new_candles`` (oldest first) into the cache for ``timeframe``.

        Candles newer than the last cached one are appended. A candle with the
        same timestamp as the last cached one replaces it, so the candle that
        was still forming at the previous fetch receives its final values.
        Older candles are ignored. The caller's list is copied, never stored
        or mutated. An empty or None batch leaves both the data and the
        freshness stamp untouched, so ``needs_update`` keeps requesting a
        refetch after a failed or empty fetch.
        """
        if new_candles is None or len(new_candles) == 0:
            return

        cached = self.candles.get(timeframe)
        if not cached:
            self.candles[timeframe] = list(new_candles)
        else:
            for candle in new_candles:
                last_timestamp = cached[-1][0]
                if candle[0] > last_timestamp:
                    cached.append(candle)
                elif candle[0] == last_timestamp:
                    cached[-1] = candle  # refresh the previously in-progress candle

        self.last_updated[timeframe] = datetime.datetime.now()

    def add_tick(self, tick_data):
        """Add a new tick to the ticks buffer, keeping only the newest MAX_TICKS."""
        self.ticks.append(tick_data)
        self.last_tick_time = datetime.datetime.now()

        # Trim in place to bound memory without rebuilding the list.
        if len(self.ticks) > self.MAX_TICKS:
            del self.ticks[:-self.MAX_TICKS]

    def get_ticks(self, limit=None):
        """Return a copy of the most recent ticks (all of them when ``limit`` is falsy or negative)."""
        if not self.ticks:
            return []
        if limit and limit > 0:
            return self.ticks[-limit:]
        return list(self.ticks)

    def get_candles(self, timeframe, limit=300):
        """Return up to ``limit`` most recent candles for ``timeframe``.

        ``limit`` of None, 0 or a negative number returns every cached candle.
        Unknown timeframes yield an empty list.
        """
        candles = self.candles.get(timeframe)
        if not candles:
            return []
        if limit is None or limit <= 0:
            return list(candles)
        return candles[-limit:]

    def needs_update(self, timeframe, max_age_seconds):
        """Return True if ``timeframe`` was never filled or is older than ``max_age_seconds``."""
        last = self.last_updated.get(timeframe)
        if not last:
            return True
        age = (datetime.datetime.now() - last).total_seconds()
        return age > max_age_seconds
|
|
|
|
async def fetch_multi_timeframe_data(exchange, symbol, candle_cache,
                                     update_intervals=None, limits=None):
    """Fetch candle data for multiple timeframes, using the cache when possible.

    A timeframe is fetched only when ``candle_cache.needs_update`` reports its
    data as older than the configured interval. Fetch errors are logged, and
    whatever is cached for that timeframe (possibly nothing) is returned.

    Args:
        exchange: Exchange client, passed through to ``fetch_ohlcv_data``.
        symbol: Trading pair symbol, e.g. 'ETH/USDT'.
        candle_cache: Cache object (``CandleCache``) holding previously fetched candles.
        update_intervals: Optional {timeframe: max_age_seconds}. Defaults to
            '1m', '1h' and '1d', refreshed after 60 s, 1 h and 1 day.
        limits: Optional {timeframe: candle_count} overriding the default
            request sizes ('1m': 1000, '1h': 500, '1d': 300). Timeframes with
            no configured limit request 1000 candles.

    Returns:
        Dict mapping each timeframe in ``update_intervals`` to its cached candles.
    """
    # TODO: For 1s/tick timeframes, we'll implement the exchange's WebSocket API
    # for real-time data streaming in the future. This will enable ultra-low latency
    # trading signals with minimal delay between market data reception and action execution.
    # A WebSocket implementation is already prepared in the RealTimeDataStream class.
    if update_intervals is None:
        update_intervals = {
            '1m': 60,      # Update every 1 minute
            '1h': 3600,    # Update every 1 hour
            '1d': 86400,   # Update every 1 day
        }

    limit_map = {'1m': 1000, '1h': 500, '1d': 300}
    if limits:
        limit_map.update(limits)

    for timeframe, interval in update_intervals.items():
        if not candle_cache.needs_update(timeframe, interval):
            continue
        try:
            logging.info(f"Fetching {timeframe} candle data for {symbol}")
            candles = await fetch_ohlcv_data(exchange, symbol, timeframe,
                                             limit_map.get(timeframe, 1000))
            if candles is None:  # treat a missing result as "no candles"
                candles = []
            candle_cache.add_candles(timeframe, candles)
            logging.info(f"Fetched {len(candles)} {timeframe} candles")
        except Exception as e:
            logging.error(f"Error fetching {timeframe} candle data: {e}")

    return {timeframe: candle_cache.get_candles(timeframe) for timeframe in update_intervals}
|
|
|
|
# LSTM + attention dueling DQN that can optionally fuse CNN features from several timeframes.
|
|
class LSTMAttentionDQN(nn.Module):
    """Dueling DQN: LSTM encoder, multi-head self-attention, value/advantage heads.

    When all three CNN feature tensors are passed to ``forward`` (presumably
    the per-timeframe outputs of a CandlePatternCNN; confirm at call sites),
    ``cnn_fusion`` compresses them into one extra time step that is appended
    to the state sequence before the LSTM.
    """

    # Width of each per-timeframe CNN feature vector consumed by cnn_fusion.
    CNN_FEATURE_DIM = 512

    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super(LSTMAttentionDQN, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.attention_heads = attention_heads

        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=state_size,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.2 if lstm_layers > 1 else 0,
        )

        # Multi-head self-attention (sequence-first layout, see _dueling_head)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=attention_heads,
            dropout=0.1,
        )

        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, action_size),
        )

        # Fusion for multi-timeframe data: 3 x 512 CNN features -> hidden_size
        self.cnn_fusion = nn.Sequential(
            nn.Linear(self.CNN_FEATURE_DIM * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, hidden_size),
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for Linear and LSTM layers."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    @staticmethod
    def _fit_last_dim(tensor, size):
        """Truncate or zero-pad the last dimension of ``tensor`` to exactly ``size``."""
        current = tensor.size(-1)
        if current > size:
            return tensor[..., :size]
        if current < size:
            return F.pad(tensor, (0, size - current))
        return tensor

    def _as_state_sequence(self, state):
        """Return ``state`` shaped (batch, seq_len, state_size).

        1-D input is one unbatched sample. 2-D input is (batch, features) and
        becomes a length-1 sequence. The feature dimension is truncated or
        zero-padded to ``state_size``.
        """
        if state.dim() == 1:
            state = state.unsqueeze(0)
        if state.dim() == 2:
            state = state.unsqueeze(1)
        return self._fit_last_dim(state, self.state_size)

    def _prepare_cnn_feature(self, features, batch_size):
        """Return one timeframe's CNN features as (batch_size, CNN_FEATURE_DIM).

        Accepts (features,), (N, features) or (N, ...) tensors; trailing
        dimensions are flattened per row. A single row (N == 1) is broadcast
        across the batch, extra rows are dropped, and the feature width is
        truncated or zero-padded.

        Raises:
            ValueError: if there are more than one but fewer than ``batch_size`` rows.
        """
        if features.dim() == 1:
            features = features.unsqueeze(0)
        elif features.dim() > 2:
            features = features.reshape(features.size(0), -1)

        rows = features.size(0)
        if rows != batch_size:
            if rows == 1:
                features = features.expand(batch_size, -1)
            elif rows > batch_size:
                features = features[:batch_size]
            else:
                raise ValueError(
                    f"CNN features have {rows} rows but the state batch has {batch_size}"
                )
        return self._fit_last_dim(features, self.CNN_FEATURE_DIM)

    def _dueling_head(self, lstm_out):
        """Self-attention over the LSTM outputs, then combine value and advantage into Q-values."""
        # nn.MultiheadAttention is built without batch_first, so it expects (seq, batch, hidden).
        attn_input = lstm_out.transpose(0, 1)
        attn_output, _ = self.attention(attn_input, attn_input, attn_input)

        # Output of the last time step after attention: (batch, hidden)
        features = attn_output[-1]

        value = self.value_stream(features)
        advantage = self.advantage_stream(features)
        return value + advantage - advantage.mean(dim=1, keepdim=True)

    def forward(self, state, x_1m=None, x_1h=None, x_1d=None):
        """
        Forward pass handling different input shapes and optional CNN features.

        Args:
            state: (batch, seq_len, features), (batch, features) or a single
                unbatched (features,) vector. The feature width is
                truncated/zero-padded to ``state_size``.
            x_1m, x_1h, x_1d: Optional CNN features, one tensor per timeframe.
                They are used only when all three are given, and each row is
                matched to the corresponding state sample (a single row is
                broadcast to the whole batch).

        Returns:
            Q-values of shape (batch, action_size).
        """
        if state.dim() == 1:
            state = state.unsqueeze(0)
        batch_size = state.size(0)
        sequence = self._as_state_sequence(state)

        if x_1m is not None and x_1h is not None and x_1d is not None:
            cnn_combined = torch.cat(
                [self._prepare_cnn_feature(x, batch_size) for x in (x_1m, x_1h, x_1d)],
                dim=1,
            )
            cnn_step = self.cnn_fusion(cnn_combined).unsqueeze(1)  # (batch, 1, hidden)
            # cnn_fusion emits hidden_size values, but LSTM steps are state_size
            # wide. Concatenating them unchanged crashed whenever the two sizes
            # differ, so fit the fused step like the state itself. This is a
            # no-op when the sizes match and needs no new weights, so existing
            # checkpoints still load.
            cnn_step = self._fit_last_dim(cnn_step, self.state_size)
            sequence = torch.cat([sequence, cnn_step], dim=1)

        lstm_out, _ = self.lstm(sequence)
        return self._dueling_head(lstm_out)

    def forward_realtime(self, x):
        """Simplified forward pass for realtime inference (state only, no CNN features).

        ``x`` may be a numpy array, a tensor or a nested list. It is moved to
        the model's device and dtype and shaped the same way ``forward`` does.
        """
        reference = next(self.parameters())
        x = torch.as_tensor(x, dtype=reference.dtype, device=reference.device)
        lstm_out, _ = self.lstm(self._as_state_sequence(x))
        return self._dueling_head(lstm_out)
|
|
|
|
# Real-time tick stream over WebSocket (the connection logic is still a placeholder).
|
|
|
|
class RealTimeDataStream:
    """
    Class for handling WebSocket API connections for ultra-low latency trading signals.

    Intended to stream data at 1-second intervals or faster for immediate
    trading decisions. NOTE: connect/subscribe/process_message are still
    placeholders. No socket is opened, and process_message ignores ``message``
    and emits a zero-priced placeholder tick.
    """

    # Cap on the number of buffered ticks and retained latency samples.
    MAX_BUFFERED = 1000

    def __init__(self, exchange, symbol, callback_fn=None):
        """
        Initialize the real-time data stream with WebSocket connection.

        Args:
            exchange: The exchange API client
            symbol: Trading pair symbol (e.g. 'ETH/USDT')
            callback_fn: Callable invoked with every processed tick; it may be
                a plain function or a coroutine function.
        """
        self.exchange = exchange
        self.symbol = symbol
        self.callback_fn = callback_fn
        self.websocket = None
        self.connected = False
        self.last_tick_time = None
        self.tick_buffer = []
        self.latency_stats = []
        self.logger = logging.getLogger(__name__)

        # Statistics for monitoring performance
        self.total_ticks = 0
        self.avg_latency_ms = 0
        self.max_latency_ms = 0

        # Candle cache for storing processed data
        self.candle_cache = CandleCache()

    async def connect(self):
        """Connect to the exchange WebSocket API (placeholder: only marks the stream connected)."""
        # TODO: Implement actual WebSocket connection logic
        self.logger.info(f"Connecting to WebSocket for {self.symbol}...")
        try:
            self.websocket = None  # Placeholder for the real connection object
            self.connected = True
            self.logger.info(f"Connected to WebSocket for {self.symbol}")
            return True
        except Exception as e:
            self.logger.error(f"WebSocket connection error: {e}")
            return False

    async def subscribe(self):
        """Subscribe to relevant data channels (placeholder: always reports success)."""
        # TODO: Implement actual WebSocket subscription logic
        self.logger.info(f"Subscribing to {self.symbol} ticks...")
        try:
            return True
        except Exception as e:
            self.logger.error(f"WebSocket subscription error: {e}")
            return False

    def _record_latency(self, latency):
        """Add one latency sample (ms) and refresh the rolling average and maximum."""
        self.latency_stats.append(latency)
        if len(self.latency_stats) > self.MAX_BUFFERED:
            del self.latency_stats[:-self.MAX_BUFFERED]
        self.avg_latency_ms = sum(self.latency_stats) / len(self.latency_stats)
        self.max_latency_ms = max(self.max_latency_ms, latency)

    async def process_message(self, message):
        """
        Process incoming WebSocket message.

        Args:
            message: The raw WebSocket message (not parsed yet, see TODO)

        Returns:
            The processed tick dict, or None if processing failed (the error is logged).
        """
        import inspect

        # TODO: Implement actual WebSocket message processing logic
        try:
            # Track tick receipt time for latency calculations
            receive_time = time.time() * 1000  # milliseconds

            # Placeholder tick; the real implementation will parse `message`.
            tick_data = {
                'timestamp': receive_time,
                'price': 0.0,
                'volume': 0.0,
                'side': 'buy',
                'exchange_time': 0,
                'latency_ms': 0,
            }

            # Latency is only measurable when the exchange supplied its own timestamp.
            exchange_time = tick_data.get('exchange_time') or 0
            if exchange_time > 0:
                latency = receive_time - exchange_time
                tick_data['latency_ms'] = latency
                self._record_latency(latency)

            # Count every processed tick, not only those that carry an exchange timestamp.
            self.total_ticks += 1

            self.tick_buffer.append(tick_data)
            if len(self.tick_buffer) > self.MAX_BUFFERED:
                del self.tick_buffer[:-self.MAX_BUFFERED]
            self.candle_cache.add_tick(tick_data)
            self.last_tick_time = datetime.datetime.now()

            # Support both plain and async callbacks.
            if self.callback_fn:
                result = self.callback_fn(tick_data)
                if inspect.isawaitable(result):
                    await result

            return tick_data
        except Exception as e:
            self.logger.error(f"Error processing WebSocket message: {e}")
            return None

    def prepare_nn_input(self, model=None, state=None):
        """
        Prepare network inputs from tick data for real-time inference.

        Args:
            model: The neural network model (currently unused)
            state: Current state representation, passed through unchanged

        Returns:
            ``{'state': state, 'ticks': tensor}``, where the tensor has shape
            (1, 1, n) and holds min-max normalized tick prices. Returns None
            when fewer than 10 priced ticks are buffered or preparation fails.
        """
        ticks = self.candle_cache.get_ticks(limit=300)
        if not ticks or len(ticks) < 10:
            # Not enough ticks for meaningful processing
            return None

        try:
            prices = np.array([t['price'] for t in ticks if 'price' in t], dtype=float)
            if len(prices) < 10:
                return None

            # Min-max normalize; a flat price series maps to all zeros.
            min_price = prices.min()
            price_range = prices.max() - min_price
            if price_range == 0:
                price_range = 1
            normalized_prices = (prices - min_price) / price_range

            # Shape (1, 1, n): flexible-length sequence for time-series analysis
            tick_tensor = torch.as_tensor(normalized_prices, dtype=torch.float32).view(1, 1, -1)

            return {
                'state': state,
                'ticks': tick_tensor,
            }
        except Exception as e:
            self.logger.error(f"Error preparing neural network input: {e}")
            return None

    def get_latency_stats(self):
        """Get statistics about WebSocket connection latency."""
        return {
            'total_ticks': self.total_ticks,
            'avg_latency_ms': self.avg_latency_ms,
            'max_latency_ms': self.max_latency_ms,
            'last_update': self.last_tick_time.isoformat() if self.last_tick_time else None,
        }

    async def close(self):
        """Close the WebSocket connection.

        Returns True when the stream is closed (including when it was never
        connected), or False if closing the underlying socket raised.
        """
        import inspect

        if not self.connected:
            return True
        try:
            websocket = self.websocket
            if websocket is not None and hasattr(websocket, 'close'):
                result = websocket.close()
                if inspect.isawaitable(result):
                    await result
            self.websocket = None
            self.connected = False
            self.logger.info(f"Closed WebSocket connection for {self.symbol}")
            return True
        except Exception as e:
            self.logger.error(f"Error closing WebSocket connection: {e}")
            return False
|
|
|
|
class BacktestCandles(CandleCache):
    """
    Special cache for backtesting that retrieves historical data from specific time periods
    without contaminating the main cache. Used for running simulations "as if" we were
    at a different point in time.
    """

    def __init__(self, since_timestamp=None, until_timestamp=None):
        """
        Initialize backtesting candle cache.

        Args:
            since_timestamp: Start timestamp for backtesting (milliseconds)
            until_timestamp: End timestamp for backtesting (milliseconds)
        """
        super().__init__()
        # Since and until timestamps for backtesting
        self.since_timestamp = since_timestamp
        self.until_timestamp = until_timestamp
        # Flag to indicate this is a backtesting cache
        self.is_backtesting = True
        # Optional name for backtesting period (e.g., "Day 1 - 24h ago")
        self.period_name = None

    async def fetch_historical_timeframe(self, exchange, symbol, timeframe, limit=1000):
        """
        Fetch historical data for a specific timeframe and time period.

        Args:
            exchange: The exchange instance
            symbol: Trading pair symbol
            timeframe: Candle timeframe
            limit: Number of candles to fetch

        Returns:
            List of candles (also stored in ``self.candles[timeframe]`` when
            non-empty), or an empty list on failure.
        """
        try:
            logging.info(f"Fetching historical {timeframe} candles for {symbol} "
                         f"(since: {self.format_timestamp(self.since_timestamp)}, "
                         f"until: {self.format_timestamp(self.until_timestamp)})")

            candles = await self.fetch_ohlcv_with_timerange(exchange, symbol, timeframe,
                                                            limit, self.since_timestamp,
                                                            self.until_timestamp)

            if candles:
                # Store in the appropriate timeframe
                self.candles[timeframe] = candles
                self.last_updated[timeframe] = datetime.datetime.now()
                logging.info(f"Fetched {len(candles)} historical {timeframe} candles for backtesting")
            else:
                logging.warning(f"No historical {timeframe} candles found for the specified time period")

            return candles
        except Exception as e:
            logging.error(f"Error fetching historical {timeframe} data: {e}")
            return []

    async def fetch_all_timeframes(self, exchange, symbol, limits=None):
        """
        Fetch historical data for all timeframes.

        Args:
            exchange: The exchange instance
            symbol: Trading pair symbol
            limits: Optional {timeframe: candle_count}. Defaults to
                {'1m': 1000, '1h': 500, '1d': 300}.

        Returns:
            Dictionary with candle data for every requested timeframe
        """
        if limits is None:
            limits = {'1m': 1000, '1h': 500, '1d': 300}

        for timeframe, limit in limits.items():
            await self.fetch_historical_timeframe(exchange, symbol, timeframe, limit)

        return {timeframe: self.get_candles(timeframe) for timeframe in limits}

    @staticmethod
    def _derive_since(exchange, timeframe, limit, until):
        """Return the start (ms) of a window of ``limit`` candles ending at ``until``.

        Uses ccxt's ``parse_timeframe`` (timeframe -> seconds) when the exchange
        provides it. Returns None when the duration cannot be determined.
        """
        parse_timeframe = getattr(exchange, 'parse_timeframe', None)
        if parse_timeframe is None:
            return None
        try:
            seconds = parse_timeframe(timeframe)
        except Exception:  # best effort: unknown timeframe strings just skip derivation
            return None
        return max(0, int(until - limit * seconds * 1000))

    @staticmethod
    async def _call_fetch_ohlcv(exchange, symbol, timeframe, since, limit):
        """Call the exchange's OHLCV endpoint whether its client is async or sync."""
        import inspect

        has = getattr(exchange, 'has', None)
        if isinstance(has, dict) and has.get('fetchOHLCVAsync', False):
            return await exchange.fetchOHLCVAsync(symbol, timeframe, since=since, limit=limit)

        # ccxt.async_support exchanges expose fetch_ohlcv as a coroutine function.
        # Running one in an executor would only return an un-awaited coroutine.
        if inspect.iscoroutinefunction(exchange.fetch_ohlcv):
            return await exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)

        # Sync client: run in the default executor to avoid blocking the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit),
        )

    async def fetch_ohlcv_with_timerange(self, exchange, symbol, timeframe, limit, since=None, until=None):
        """
        Fetch OHLCV data within a specific time range.

        Only one request of ``limit`` candles is made per attempt (no
        pagination). Candles after ``until`` are dropped. When only ``until`` is
        given, ``since`` is derived so that the window ends at ``until``
        instead of at the present.

        Args:
            exchange: The exchange instance
            symbol: Trading pair symbol
            timeframe: Candle timeframe
            limit: Number of candles to fetch
            since: Start timestamp (milliseconds)
            until: End timestamp (milliseconds)

        Returns:
            List of [timestamp, open, high, low, close, volume] rows, or an
            empty list after ``max_retries`` failed attempts.
        """
        max_retries = 3
        retry_delay = 5

        if since is None and until is not None:
            since = self._derive_since(exchange, timeframe, limit, until)

        for attempt in range(max_retries):
            if attempt:
                # Wait between attempts only; do not sleep after the final one.
                await asyncio.sleep(retry_delay)
            try:
                logging.info(f"Fetching {limit} {timeframe} candles for {symbol} "
                             f"(since: {self.format_timestamp(since)}, "
                             f"until: {self.format_timestamp(until)}) "
                             f"(attempt {attempt+1}/{max_retries})")

                if not hasattr(exchange, 'fetch_ohlcv'):
                    logging.error("Exchange does not support OHLCV data fetching")
                    return []

                try:
                    ohlcv = await self._call_fetch_ohlcv(exchange, symbol, timeframe, since, limit)
                except Exception as e:
                    logging.error(f"Failed to fetch OHLCV data: {e}")
                    continue

                if not ohlcv:
                    logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})")
                    continue

                if until is not None:
                    ohlcv = [candle for candle in ohlcv if candle[0] <= until]

                # Keep the first six fields so extra per-candle fields do not break parsing.
                data = [list(candle[:6]) for candle in ohlcv]

                logging.info(f"Successfully fetched {len(data)} historical candles")
                return data

            except Exception as e:
                logging.error(f"Error fetching historical OHLCV data (attempt {attempt+1}/{max_retries}): {e}")

        logging.error(f"Failed to fetch historical OHLCV data after {max_retries} attempts")
        return []

    def format_timestamp(self, timestamp):
        """Format a millisecond timestamp (local time) for readable logging."""
        if timestamp is None:
            return "None"
        try:
            moment = datetime.datetime.fromtimestamp(timestamp / 1000.0)
            return moment.strftime('%Y-%m-%d %H:%M:%S')
        except (TypeError, ValueError, OverflowError, OSError):
            return str(timestamp)
|
|
|
|
def _backtest_clean_memory():
    """Release cached CUDA memory (when CUDA is available) and run the garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


async def _backtest_close_exchange(exchange):
    """Await ``exchange.close()`` if *exchange* is set; log, never raise, on failure."""
    if not exchange:
        return
    try:
        await exchange.close()
        logging.info("Exchange connection closed successfully")
    except AttributeError:
        # Some exchanges don't have a close method
        logging.info("Exchange doesn't have a close method, skipping")
    except Exception as e:
        logging.error(f"Error closing exchange connection: {e}")


def _backtest_run_episode(agent, env, candle_data, max_steps_per_episode):
    """
    Run a single backtest training episode.

    Resets *env* and refreshes its CNN patterns from *candle_data*. It then
    runs up to *max_steps_per_episode* steps. Each step selects an action,
    steps the environment and pushes the transition to ``agent.memory``. Once
    the memory holds more than BATCH_SIZE transitions, each step also calls
    ``agent.learn()``.

    The target network is updated whenever ``step % TARGET_UPDATE == 0``.
    The episode ends early when the environment reports ``done``. It also
    ends after 5 step/learning errors with no successful ``agent.learn()``
    call in between.

    Returns:
        tuple: ``(episode_reward, steps_taken, episode_losses)``;
        ``steps_taken`` is 0 when *max_steps_per_episode* is 0.
    """
    state = env.reset()
    episode_reward = 0
    episode_losses = []
    steps_taken = 0

    # Update CNN patterns with historical data
    env.update_cnn_patterns(candle_data)

    # Circuit breaker: abort the episode after this many errors in a row
    consecutive_errors = 0
    max_consecutive_errors = 5

    for step in range(max_steps_per_episode):
        steps_taken = step + 1
        try:
            # Select action using CNN-enhanced policy
            action = agent.select_action(state, training=True, candle_data=candle_data)
            next_state, reward, done, _info = env.step(action)

            # Store transition in replay memory and advance
            agent.memory.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            # Learn from experience once enough transitions are stored
            if len(agent.memory) > BATCH_SIZE:
                try:
                    loss = agent.learn()
                    if loss is not None:
                        episode_losses.append(loss)
                    # Reset consecutive errors counter on successful learning
                    consecutive_errors = 0
                except Exception as e:
                    logging.error(f"Error during learning: {e}")
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                        break

            # Update target network periodically
            if step % TARGET_UPDATE == 0:
                agent.update_target_network()

            # Clean memory periodically during long episodes
            if step % 200 == 0 and step > 0:
                _backtest_clean_memory()

            if done:
                break

        except Exception as e:
            logging.error(f"Error in training step: {e}")
            consecutive_errors += 1
            if consecutive_errors >= max_consecutive_errors:
                logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                break

    return episode_reward, steps_taken, episode_losses


def _backtest_save_best(agent, model_path, label, value_text):
    """Save *agent* to *model_path* after a new best *label*; log the outcome either way."""
    try:
        agent.save(model_path)
        logging.info(f"New best {label}: {value_text}")
    except Exception as e:
        logging.error(f"Error saving best {label} model: {e}")
        logging.info(f"New best {label}: {value_text} (model not saved)")


def _backtest_write_stats_csv(stats, stats_file):
    """Write one CSV row per episode from the per-episode lists in *stats*."""
    columns = ('episode_rewards', 'balances', 'episode_pnls', 'fees',
               'net_pnl_after_fees', 'win_rates', 'trade_counts', 'loss_values')
    with open(stats_file, 'w', newline='') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['Episode', 'Reward', 'Balance', 'PnL', 'Fees', 'Net PnL', 'Win Rate', 'Trades', 'Loss'])
        rows = zip(*(stats[key] for key in columns))
        for episode_number, row in enumerate(rows, start=1):
            csv_writer.writerow([episode_number, *row])


async def train_with_backtesting(agent, env, symbol="ETH/USDT",
                                 since_timestamp=None, until_timestamp=None,
                                 num_episodes=10, max_steps_per_episode=1000,
                                 period_name=None, writer=None,
                                 use_compact_save=False):
    """
    Train agent with backtesting on historical data.

    Fetches historical candles for *symbol* through a ``BacktestCandles``
    cache and trains *agent* for *num_episodes* episodes against *env*.

    Models are saved under ``models/backtest/`` for the best reward, best PnL
    and best net PnL (PnL minus fees). Every 10th episode a checkpoint is
    written to ``models/trading_agent_checkpoint_{episode}.pt``. A final model
    is also written when *period_name* is set. Per-episode statistics are
    written to ``backtest_stats[_{period_name}].csv``. The exchange
    connection is closed on every exit path.

    Args:
        agent: The agent to train
        env: Trading environment
        symbol: Trading pair symbol
        since_timestamp: Start timestamp for backtesting
        until_timestamp: End timestamp for backtesting
        num_episodes: Number of episodes to train
        max_steps_per_episode: Maximum steps per episode
        period_name: Name of the backtest period
        writer: Optional object with ``add_scalar(tag, value, step)`` (e.g. a
            TensorBoard ``SummaryWriter``). When given, cumulative PnL and
            cumulative net PnL are logged per episode.
        use_compact_save: If True, periodic checkpoints are written with
            ``compact_save`` instead of ``agent.save``.

    Returns:
        Training statistics dictionary. Returns ``None`` when the exchange
        cannot be initialized or no 1-minute candles could be fetched.
    """
    # Create a backtesting candle cache
    backtest_cache = BacktestCandles(since_timestamp, until_timestamp)
    if period_name:
        backtest_cache.period_name = period_name
        logging.info(f"Starting backtesting for period: {period_name}")

    # Initialize exchange for data fetching
    try:
        exchange = await initialize_exchange()
        logging.info("Initialized exchange for backtesting")
    except Exception as e:
        logging.error(f"Failed to initialize exchange: {e}")
        return None

    # try/finally guarantees the exchange is closed on every exit path below
    try:
        # Fetch historical data for all timeframes
        try:
            _backtest_clean_memory()  # Clean memory before fetching data
            candle_data = await backtest_cache.fetch_all_timeframes(exchange, symbol)
            has_minute_candles = bool(candle_data) and bool(candle_data['1m'])
        except Exception as e:
            logging.error(f"Failed to fetch historical data for backtesting: {e}")
            return None

        if not has_minute_candles:
            logging.error(f"No historical data available for backtesting period: {period_name}")
            return None

        logging.info(f"Fetched historical data for backtesting: {len(candle_data['1m'])} minute candles")

        # Initialize statistics tracking
        stats = {
            'period': period_name,
            'since_timestamp': since_timestamp,
            'until_timestamp': until_timestamp,
            'episode_rewards': [],
            'episode_lengths': [],
            'balances': [],
            'win_rates': [],
            'episode_pnls': [],
            'cumulative_pnl': [],
            'drawdowns': [],
            'trade_counts': [],
            'loss_values': [],
            'fees': [],
            'net_pnl_after_fees': []
        }

        # Track best models
        best_reward = float('-inf')
        best_pnl = float('-inf')
        best_net_pnl = float('-inf')

        # Make directory for backtesting models if it doesn't exist
        os.makedirs('models/backtest', exist_ok=True)
        best_prefix = f"models/backtest/{period_name}_" if period_name else "models/backtest/"

        for episode in range(num_episodes):
            try:
                # Clean memory before starting a new episode
                _backtest_clean_memory()

                episode_reward, steps_taken, episode_losses = _backtest_run_episode(
                    agent, env, candle_data, max_steps_per_episode)

                # Gather every metric before touching `stats` so a missing
                # attribute can't leave the per-episode lists at unequal lengths.
                mean_loss = np.mean(episode_losses) if episode_losses else 0
                balance = env.balance
                pnl = balance - env.initial_balance
                fees = env.total_fees
                net_pnl = pnl - fees
                win_rate = getattr(env, 'win_rate', 0)
                trade_count = getattr(env, 'trade_count', 0)
                max_drawdown = getattr(env, 'max_drawdown', 0)

                # Update epsilon for exploration (once per episode)
                agent.update_epsilon(episode)

                stats['episode_rewards'].append(episode_reward)
                stats['episode_lengths'].append(steps_taken)
                stats['balances'].append(balance)
                stats['win_rates'].append(win_rate)
                stats['episode_pnls'].append(pnl)
                stats['drawdowns'].append(max_drawdown)
                stats['trade_counts'].append(trade_count)
                stats['loss_values'].append(mean_loss)
                stats['fees'].append(fees)
                stats['net_pnl_after_fees'].append(net_pnl)

                cumulative_pnl = sum(stats['episode_pnls'])
                stats['cumulative_pnl'].append(cumulative_pnl)
                if writer is not None:
                    writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
                    writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)

                # Save model if this is the best reward / PnL / net PnL so far
                if episode_reward > best_reward:
                    best_reward = episode_reward
                    _backtest_save_best(agent, f"{best_prefix}best_reward.pt", "reward", f"{best_reward:.2f}")

                if pnl > best_pnl:
                    best_pnl = pnl
                    _backtest_save_best(agent, f"{best_prefix}best_pnl.pt", "PnL", f"${best_pnl:.2f}")

                if net_pnl > best_net_pnl:
                    best_net_pnl = net_pnl
                    _backtest_save_best(agent, f"{best_prefix}best_net_pnl.pt", "Net PnL", f"${best_net_pnl:.2f}")

                # Save checkpoint periodically
                if episode % 10 == 0:
                    checkpoint_path = f'models/trading_agent_checkpoint_{episode}.pt'
                    try:
                        if use_compact_save:
                            # NOTE(review): assumes the agent exposes policy_net, optimizer,
                            # epsilon, state_size, action_size and hidden_size — confirm.
                            saved = compact_save(agent.policy_net, agent.optimizer, episode_reward,
                                                 agent.epsilon, agent.state_size, agent.action_size,
                                                 agent.hidden_size, checkpoint_path)
                            if not saved:
                                logging.warning(f"Compact checkpoint save failed for {checkpoint_path}")
                        else:
                            agent.save(checkpoint_path)
                    except Exception as e:
                        logging.error(f"Error saving checkpoint model: {e}")

                # Log training progress
                logging.info(f"Episode {episode+1}/{num_episodes} | "
                             f"Reward: {episode_reward:.2f} | "
                             f"Balance: ${balance:.2f} | "
                             f"PnL: ${pnl:.2f} | "
                             f"Fees: ${fees:.2f} | "
                             f"Net PnL: ${net_pnl:.2f} | "
                             f"Win Rate: {win_rate:.2f} | "
                             f"Trades: {trade_count} | "
                             f"Loss: {mean_loss:.5f} | "
                             f"Epsilon: {agent.epsilon:.4f}")

            except Exception as e:
                logging.error(f"Error in episode {episode}: {e}")
                logging.error(traceback.format_exc())

        # Clean memory before saving final model
        _backtest_clean_memory()

        if period_name:
            try:
                agent.save(f"models/backtest/{period_name}_final.pt")
                logging.info(f"Saved final model for period: {period_name}")
            except Exception as e:
                logging.error(f"Error saving final model: {e}")

        # Save backtesting statistics
        stats_file = f"backtest_stats_{period_name}.csv" if period_name else "backtest_stats.csv"
        try:
            _backtest_write_stats_csv(stats, stats_file)
            logging.info(f"Backtesting statistics saved to {stats_file}")
        except Exception as e:
            logging.error(f"Error saving backtesting statistics: {e}")

        return stats

    finally:
        await _backtest_close_exchange(exchange)
|
|
|
|
# Implement a robust save function to handle PyTorch serialization errors
def robust_save(model, path):
    """
    Save an agent checkpoint, with fallbacks for low-disk and serialization errors.

    Attempts, in order:
      1. ``torch.save`` the full checkpoint (policy_net, target_net, optimizer,
         epsilon) to ``{path}.backup``, then copy it to *path*. If the copy
         fails, the backup is renamed onto *path* instead. A rename needs no
         extra space on the same filesystem, so it can succeed when there is
         no room for a second copy.
      2. Save the full checkpoint directly to *path* with the legacy
         (non-zipfile) format and pickle protocol 2.
      3. As 2, but without the optimizer state.
      4. Write only hyper-parameters to ``{path}.params.json``, plus a
         ``{path}.minimal`` torch file. No network weights are saved.

    Args:
        model: Agent exposing ``policy_net``, ``target_net``, ``optimizer``,
            ``epsilon`` and, for attempt 4, ``state_size``, ``action_size``
            and ``hidden_size``.
        path: Destination checkpoint path.

    Returns:
        bool: True if a checkpoint with network weights was written. It is at
        *path*, or at ``{path}.backup`` when neither the copy nor the rename
        succeeded. Returns False otherwise, including when only the attempt-4
        parameter files could be written.
    """
    backup_path = f"{path}.backup"

    def _checkpoint(include_optimizer=True):
        # Build a fresh checkpoint dict for each attempt.
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
        }
        if include_optimizer:
            checkpoint['optimizer'] = model.optimizer.state_dict()
        checkpoint['epsilon'] = model.epsilon
        return checkpoint

    # Attempt 1: Regular save to backup file, then publish it at `path`
    logger.info(f"Saving model to {path}.backup (attempt 1)")
    try:
        torch.save(_checkpoint(), backup_path)
        logger.info(f"Successfully saved to {backup_path}")
    except Exception as e:
        logger.warning(f"First save attempt failed: {str(e)}")
    else:
        try:
            shutil.copy2(backup_path, path)
            logger.info(f"Copied backup to {path}")
        except Exception as e:
            logger.warning(f"Failed to copy backup to main file: {str(e)}")
            try:
                # Renaming overwrites any partially copied file and needs no extra space.
                os.replace(backup_path, path)
                logger.info(f"Moved backup to {path}")
            except OSError as move_err:
                logger.warning(f"Failed to move backup to main file: {move_err}")
                logger.info(f"Using backup file as the main save")
                return True
        logger.info(f"Model saved successfully to {path}")
        return True

    # Attempt 2: Try with older pickle protocol
    logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
    try:
        torch.save(_checkpoint(), path, _use_new_zipfile_serialization=False, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with protocol 2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {str(e)}")

    # Attempt 3: Try without optimizer
    logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
    try:
        torch.save(_checkpoint(include_optimizer=False), path,
                   _use_new_zipfile_serialization=False, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} without optimizer")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {str(e)}")

    # Attempt 4: Save only the model hyper-parameters (no weights) for recovery
    logger.info(f"Saving model to {path} (attempt 4 - model structure as JSON)")
    try:
        model_params = {
            'epsilon': float(model.epsilon),
            'state_size': model.state_size,
            'action_size': model.action_size,
            'hidden_size': model.hidden_size,
            'lstm_layers': getattr(model.policy_net, 'lstm_layers', 2),
            'attention_heads': getattr(model.policy_net, 'attention_heads', 4)
        }

        params_path = f"{path}.params.json"
        with open(params_path, 'w') as f:
            json.dump(model_params, f)
        logger.info(f"Successfully saved model parameters to {params_path}")

        # Minimal torch checkpoint with scalar metadata only, for recovery purposes
        try:
            minimal_checkpoint = {
                'epsilon': model.epsilon,
                'state_size': model.state_size,
                'action_size': model.action_size,
                'hidden_size': model.hidden_size
            }
            minimal_path = f"{path}.minimal"
            torch.save(minimal_checkpoint, minimal_path, _use_new_zipfile_serialization=False, pickle_protocol=2)
            logger.info(f"Successfully saved minimal checkpoint to {minimal_path}")
        except Exception as e:
            logger.warning(f"Minimal checkpoint save failed: {str(e)}")

        # No network weights were written, so this is not a successful model save.
        logger.error(f"Could not save model weights to {path}; only parameters were saved to {params_path}")
        return False
    except Exception as e:
        logger.error(f"All save attempts failed for {path}: {str(e)}")
        return False
|
|
|
|
def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False, models_dir="models"):
    """
    Delete old model files to free up disk space.

    Only plain files directly inside *models_dir* are considered;
    sub-directories are ignored. Files are grouped by name:

    * checkpoints: the name contains ``"checkpoint"`` and ends with ``.pt``.
      All but the *keep_latest_n* most recently modified are deleted.
    * backups: the name ends with ``.backup``. With *aggressive*, all but the
      newest are deleted.
    * ``*.params.json`` files are never deleted.
    * dated files: the name contains ``"_2025"`` or ``"_2024"``. These are
      deleted only with *aggressive*.

    Errors are logged, never raised. Afterwards the free disk space is logged,
    with a warning when it is below 200 MB and *aggressive* was not used.

    Args:
        keep_best (bool): Whether to keep the best model files (reward, pnl,
            net_pnl) and ``trading_agent_final.pt``.
        keep_latest_n (int): Number of latest checkpoint files to keep.
        aggressive (bool): If True, apply more aggressive cleanup in very low
            disk scenarios.
        models_dir (str): Directory to clean. Defaults to ``"models"``.
    """
    # Best files to keep if keep_best is True
    best_names = frozenset({
        "trading_agent_best_reward.pt",
        "trading_agent_best_pnl.pt",
        "trading_agent_best_net_pnl.pt",
        "trading_agent_final.pt",
    })

    def _newest_first(entries):
        # entries are (mtime, path) tuples; sort by mtime only, newest first
        entries.sort(key=lambda entry: entry[0], reverse=True)

    def _delete(file_path, description):
        """Remove one file; return the bytes freed, or None if it could not be removed."""
        try:
            file_size = os.path.getsize(file_path)
            os.remove(file_path)
        except OSError as e:
            logging.error(f"Failed to delete file {file_path}: {str(e)}")
            return None
        logging.info(f"Deleted {description}: {file_path}")
        return file_size

    try:
        logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}, aggressive={aggressive}")

        if not os.path.isdir(models_dir):
            logging.warning(f"Models directory not found, nothing to clean up: {models_dir}")
            return

        checkpoint_files = []
        backup_files = []
        dated_files = []

        # Categorize files for potential deletion
        for filename in os.listdir(models_dir):
            file_path = os.path.join(models_dir, filename)
            if os.path.isdir(file_path):
                continue
            if keep_best and filename in best_names:
                continue

            if "checkpoint" in filename and filename.endswith(".pt"):
                bucket = checkpoint_files
            elif filename.endswith(".backup"):
                bucket = backup_files
            elif filename.endswith(".params.json"):
                continue  # small parameter files are never deleted
            elif "_2025" in filename or "_2024" in filename:  # Files with date stamps
                bucket = dated_files
            else:
                continue
            bucket.append((os.path.getmtime(file_path), file_path))

        # Decide what to delete, in order: checkpoints, backups, dated files
        to_delete = []
        if len(checkpoint_files) > keep_latest_n:
            _newest_first(checkpoint_files)
            to_delete += [(p, "old checkpoint file") for _, p in checkpoint_files[keep_latest_n:]]

        if aggressive:
            _newest_first(backup_files)
            to_delete += [(p, "old backup file") for _, p in backup_files[1:]]  # keep newest backup
            to_delete += [(p, "dated model file") for _, p in dated_files]

        files_deleted = 0
        bytes_freed = 0
        for file_path, description in to_delete:
            freed = _delete(file_path, description)
            if freed is not None:
                files_deleted += 1
                bytes_freed += freed

        logging.info(f"Cleanup complete. Deleted {files_deleted} files, freed {bytes_freed / (1024*1024):.2f} MB")

        # Check available disk space after cleanup (portable across platforms)
        try:
            free_mb = shutil.disk_usage(os.path.abspath(models_dir)).free / (1024 * 1024)
            logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB")

            if free_mb < 200 and not aggressive:  # Less than 200MB available
                logging.warning("Disk space still critically low. Consider using aggressive cleanup.")
        except OSError as e:
            logging.error(f"Error checking disk space: {str(e)}")

    except Exception as e:
        logging.error(f"Error during file cleanup: {str(e)}")
        logging.error(traceback.format_exc())
|
|
|
|
def compact_save(model, optimizer, reward, epsilon, state_size, action_size, hidden_size, path, use_quantization=False):
    """
    Save a model in a compact format suitable for low disk space environments.

    Only the model's ``state_dict`` and a few scalars are written: epsilon,
    the state/action/hidden sizes, and the reward. They are saved with the
    legacy (non-zipfile) ``torch.save`` format and pickle protocol 2. The
    optimizer state is not written. Includes fallbacks if the primary save
    method fails.

    Args:
        model: The model to save
        optimizer: Accepted for interface compatibility; its state is NOT saved
        reward: The current reward, stored under ``'reward'`` as a float, or
            ``None`` if it cannot be converted to one
        epsilon: The current epsilon value
        state_size: The state size
        action_size: The action size
        hidden_size: The hidden size
        path: The path to save to
        use_quantization: Whether to dynamically quantize ``nn.Linear`` layers
            to int8 to reduce model size. Falls back to a regular save if
            quantization fails.

    Returns:
        bool: True if a checkpoint with weights was written to *path*. False
        otherwise; in that case the scalar parameters are written to
        ``{path}.params.json`` when possible.
    """
    try:
        # A plain float keeps the checkpoint loadable with torch.load(weights_only=True).
        reward_value = float(reward)
    except (TypeError, ValueError, RuntimeError):
        reward_value = None

    def _bundle(state_dict, **extra):
        # Minimal checkpoint with essential data only
        bundle = {
            'model_state_dict': state_dict,
            'epsilon': epsilon,
            'state_size': state_size,
            'action_size': action_size,
            'hidden_size': hidden_size,
            'reward': reward_value,
        }
        bundle.update(extra)
        return bundle

    try:
        checkpoint = _bundle(model.state_dict())

        if use_quantization:
            try:
                logging.info(f"Attempting quantized save to {path}")
                # Quantize Linear layers to int8 (returns a quantized copy of the model)
                quantized_model = torch.quantization.quantize_dynamic(
                    model,
                    {torch.nn.Linear},
                    dtype=torch.qint8
                )
                torch.save(_bundle(quantized_model.state_dict(), is_quantized=True), path,
                           _use_new_zipfile_serialization=False, pickle_protocol=2)
                logging.info(f"Quantized compact save successful to {path}")
                return True
            except Exception as e:
                # Fall back to regular save if quantization fails
                logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}")

        # Regular save with older pickle protocol and no zipfile serialization
        torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
        logging.info(f"Compact save successful to {path}")
        return True

    except Exception as e:
        logging.error(f"Compact save failed: {str(e)}")
        logging.error(traceback.format_exc())

        # Fallback: Save just the parameters as JSON if we can't save the full model
        try:
            params = {
                'epsilon': epsilon,
                'state_size': state_size,
                'action_size': action_size,
                'hidden_size': hidden_size,
                'reward': reward_value
            }
            json_path = f"{path}.params.json"
            with open(json_path, 'w') as f:
                json.dump(params, f)
            logging.info(f"Saved minimal parameters to {json_path}")
            return False
        except Exception as json_e:
            logging.error(f"JSON parameter save failed: {str(json_e)}")
            return False
|
|
|
|
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, default='train', help='Mode: train, test, live')
    parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode')
    parser.add_argument('--update_interval', type=int, default=10, help='Target network update interval')
    parser.add_argument('--training_iterations', type=int, default=10, help='Number of training iterations per step')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for candlestick data')
    parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage')
    parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes')
    parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training')
    parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space')
    parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up')

    args = parser.parse_args()

    # platform/ctypes (used for disk-space checks) are already imported at the top of the module.

    # Run cleanup if requested
    if args.cleanup:
        cleanup_model_files(keep_best=True, keep_latest_n=args.keep_latest, aggressive=args.aggressive_cleanup)

    # NOTE(review): `args` is not passed to main() here; confirm main() parses
    # or otherwise receives the remaining command-line options itself.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")