# gogo2/main.py
import os
import time
import json
import numpy as np
import pandas as pd
import random
import logging
import asyncio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from dotenv import load_dotenv
import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp  # mixed-precision (AMP) training utilities
from sklearn.preprocessing import MinMaxScaler
import copy
import argparse
import traceback
import io
import matplotlib.dates as mdates
from matplotlib.figure import Figure
from PIL import Image
import matplotlib.pyplot as mpf  # NOTE: duplicate of plt above; the `mpf` alias suggests mplfinance was intended
import matplotlib.gridspec as gridspec
import datetime  # module import; the datetime class is aliased as `dt` below
from datetime import datetime as dt
from realtime import BinanceWebSocket, BinanceHistoricalData
# Dash imports for the live dashboard
import dash
from dash import html, dcc, callback_context
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from threading import Thread
import socket
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)
logger = logging.getLogger("trading_bot")
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
# Constants
INITIAL_BALANCE = 100 # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5%
MEMORY_SIZE = 100000
BATCH_SIZE = 64
GAMMA = 0.99 # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 64 # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10 # Sync target network every 10 learning steps (see Agent.learn)
# Experience replay tuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
# Helper for labeling optimal entry/exit points in historical price data
def find_local_extrema(prices, window=5):
"""Find local minima (bottoms) and maxima (tops) in price data"""
bottoms = []
tops = []
if len(prices) < window * 2 + 1:
return bottoms, tops
for i in range(window, len(prices) - window):
# Check if this is a local minimum (bottom)
if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
all(prices[i] <= prices[i+j] for j in range(1, window+1)):
bottoms.append(i)
# Check if this is a local maximum (top)
if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
all(prices[i] >= prices[i+j] for j in range(1, window+1)):
tops.append(i)
return bottoms, tops
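# Example: find_local_extrema([5, 4, 3, 4, 5, 4, 3], window=2) returns ([2], [4]):
# index 2 is a local bottom (3 is <= both neighbors on each side) and index 4 is a local top.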
class ReplayMemory:
def __init__(self, capacity):
self.memory = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.memory.append(Experience(state, action, reward, next_state, done))
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
class DQN(nn.Module):
def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
super(DQN, self).__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.lstm_layers = lstm_layers
# Initial feature extraction
self.fc1 = nn.Linear(state_size, hidden_size)
# Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
self.ln1 = nn.LayerNorm(hidden_size)
self.dropout1 = nn.Dropout(0.2)
# LSTM layer for sequential data
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)
# Attention mechanism
self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
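# NOTE: this attention layer is currently not used in forward(); the transformer
# encoder below handles attention-based pattern recognition instead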
# Output layers with increased capacity
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm
self.dropout2 = nn.Dropout(0.2)
self.fc3 = nn.Linear(hidden_size, hidden_size // 2)
# Dueling DQN architecture
self.value_stream = nn.Linear(hidden_size // 2, 1)
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
# Transformer encoder for more complex pattern recognition
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
def forward(self, x):
batch_size = x.size(0) if x.dim() > 1 else 1
# Ensure input has correct shape
if x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension
# Check if state size matches expected input size
if x.size(1) != self.state_size:
# Handle mismatched input by either truncating or padding
if x.size(1) > self.state_size:
x = x[:, :self.state_size] # Truncate
else:
# Pad with zeros
padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
x = torch.cat([x, padding], dim=1)
# Initial feature extraction
x = self.fc1(x)
x = F.relu(self.ln1(x)) # LayerNorm works with any batch size
x = self.dropout1(x)
# Reshape for LSTM
x_lstm = x.unsqueeze(1) if x.dim() == 2 else x
# Process through LSTM
lstm_out, _ = self.lstm(x_lstm)
lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]
# Process through transformer for more complex patterns
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
# Combine LSTM and transformer outputs
x = lstm_out + transformer_out
# Final layers
x = self.fc2(x)
x = F.relu(self.ln2(x)) # LayerNorm works with any batch size
x = self.dropout2(x)
x = F.relu(self.fc3(x))
# Dueling architecture
value = self.value_stream(x)
advantages = self.advantage_stream(x)
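# Subtracting the mean advantage keeps the decomposition identifiable:
# Q(s,a) = V(s) + (A(s,a) - mean_a' A(s,a'))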
qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))
return qvals
class PricePredictionModel(nn.Module):
def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):  # input_size is the lookback window length; the LSTM itself takes 1 feature per step
super(PricePredictionModel, self).__init__()
self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
self.fc = nn.Linear(hidden_size, output_size)
self.scaler = MinMaxScaler(feature_range=(0, 1))
self.is_fitted = False
def forward(self, x):
# x shape: [batch_size, seq_len, 1]
lstm_out, _ = self.lstm(x)
# Use the last time step output
predictions = self.fc(lstm_out[:, -1, :])
return predictions
def preprocess(self, data):
# Reshape data for scaler
data_reshaped = np.array(data).reshape(-1, 1)
# Fit scaler if not already fitted
if not self.is_fitted:
self.scaler.fit(data_reshaped)
self.is_fitted = True
# Transform data
scaled_data = self.scaler.transform(data_reshaped)
return scaled_data
def postprocess(self, scaled_predictions):
# Inverse transform to get actual price values
return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()
def predict_next_candles(self, price_history, num_candles=5):
if len(price_history) < 30: # Need enough history
return np.zeros(num_candles)
# Preprocess data
scaled_data = self.preprocess(price_history)
# Create sequence
sequence = scaled_data[-30:].reshape(1, 30, 1)
sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)
# Get predictions
with torch.no_grad():
scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
# Postprocess predictions
predictions = self.postprocess(scaled_predictions)
return predictions
def train_on_new_data(self, price_history, optimizer, epochs=10):
if len(price_history) < 35: # Need enough history for training
return 0.0
# Preprocess data
scaled_data = self.preprocess(price_history)
# Create sequences and targets
sequences = []
targets = []
for i in range(len(scaled_data) - 35):
# Sequence: 30 time steps
seq = scaled_data[i:i+30]
# Target: next 5 time steps
target = scaled_data[i+30:i+35].flatten()
sequences.append(seq)
targets.append(target)
if not sequences: # If no sequences were created
return 0.0
# Convert to tensors
sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device)
targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device)
# Training loop
total_loss = 0
for _ in range(epochs):
# Forward pass
predictions = self(sequences_tensor)
# Calculate loss
loss = F.mse_loss(predictions, targets_tensor)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / epochs
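# Usage sketch (illustrative, not part of the pipeline below): incrementally train
# the predictor on a list of recent closes, then forecast the next five candles.
#   model = PricePredictionModel()
#   opt = optim.Adam(model.parameters(), lr=1e-3)
#   loss = model.train_on_new_data(closes, opt, epochs=10)  # closes: list of floats
#   next_five = model.predict_next_candles(closes)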
class TradingEnvironment:
def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
"""Initialize the trading environment"""
self.initial_balance = initial_balance
self.balance = initial_balance
self.window_size = window_size
self.demo = demo
self.data = []
self.position = 'flat' # 'flat', 'long', or 'short'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.trades = []
self.win_count = 0
self.loss_count = 0
self.total_pnl = 0.0
self.episode_pnl = 0.0
self.peak_balance = initial_balance
self.max_drawdown = 0.0
self.current_step = 0
self.current_price = 0
# For tracking signals for visualization
self.trade_signals = []
# Initialize features
self.features = {
'price': [],
'volume': [],
'rsi': [],
'macd': [],
'macd_signal': [],
'macd_hist': [],
'bollinger_upper': [],
'bollinger_mid': [],
'bollinger_lower': [],
'stoch_k': [],
'stoch_d': [],
'ema_9': [],
'ema_21': [],
'atr': []
}
# Initialize price predictor
self.price_predictor = None
self.predicted_prices = np.array([])
# Initialize optimal trade tracking
self.optimal_bottoms = []
self.optimal_tops = []
self.optimal_signals = np.array([])
# Futures trading attributes
self.leverage = MAX_LEVERAGE
self.futures_symbol = "ETH_USDT" # Example futures symbol
self.position_mode = "hedge" # For simultaneous long/short positions
self.margin_mode = "cross" # Cross margin mode
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
self.position = 'flat'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.trades = []
self.win_count = 0
self.loss_count = 0
self.episode_pnl = 0.0
self.peak_balance = self.initial_balance
self.max_drawdown = 0.0
self.current_step = 0
# Keep data but reset current position
if len(self.data) > self.window_size:
self.current_step = self.window_size
self.current_price = self.data[self.current_step]['close']
# Reset trade signals
self.trade_signals = []
return self.get_state()
def add_data(self, candle):
"""Add a new candle to the data"""
self.data.append(candle)
self._update_features()
self.current_price = candle['close']
def _initialize_features(self):
"""Initialize technical indicators and features"""
if len(self.data) < 30:
return
# Convert data to pandas DataFrame for easier calculation
df = pd.DataFrame(self.data)
# Basic price and volume
self.features['price'] = df['close'].values
self.features['volume'] = df['volume'].values
# Calculate RSI (14 periods)
delta = df['close'].diff()
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
rs = gain / loss
self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values
# Calculate MACD
ema12 = df['close'].ewm(span=12, adjust=False).mean()
ema26 = df['close'].ewm(span=26, adjust=False).mean()
macd = ema12 - ema26
signal = macd.ewm(span=9, adjust=False).mean()
self.features['macd'] = macd.values
self.features['macd_signal'] = signal.values
self.features['macd_hist'] = (macd - signal).values
# Calculate Bollinger Bands
sma20 = df['close'].rolling(window=20).mean()
std20 = df['close'].rolling(window=20).std()
self.features['bollinger_upper'] = (sma20 + 2 * std20).values
self.features['bollinger_mid'] = sma20.values
self.features['bollinger_lower'] = (sma20 - 2 * std20).values
# Calculate Stochastic Oscillator
low_14 = df['low'].rolling(window=14).min()
high_14 = df['high'].rolling(window=14).max()
k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
self.features['stoch_k'] = k.values
self.features['stoch_d'] = k.rolling(window=3).mean().values
# Calculate EMAs
self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values
# Calculate ATR
high_low = df['high'] - df['low']
high_close = (df['high'] - df['close'].shift()).abs()
low_close = (df['low'] - df['close'].shift()).abs()
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values
def _update_features(self):
"""Update technical indicators with new data"""
self._initialize_features() # Recalculate all features
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""Fetch initial historical data for the environment"""
try:
logger.info(f"Fetching initial data for {symbol}")
# Use the refactored fetch method
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
# Update environment with fetched data
if data:
self.data = data
self._initialize_features()
logger.info(f"Initialized environment with {len(data)} candles")
else:
logger.warning("No initial data received")
return bool(data)
except Exception as e:
logger.error(f"Error fetching initial data: {e}")
return False
def step(self, action):
"""Take an action in the environment and return the next state, reward, and done flag"""
# Check if we have enough data
if self.current_step >= len(self.data) - 1:
# We've reached the end of data
done = True
next_state = self.get_state()
info = {
'action': 'none',
'price': self.current_price,
'balance': self.balance,
'position': self.position,
'pnl': self.total_pnl
}
return next_state, 0, done, info
# Store current price before taking action
self.current_price = self.data[self.current_step]['close']
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
# Note: calculate_reward() both scores and executes the action (it mutates position state)
reward = self.calculate_reward(action)
# Record trade signal for visualization
if action > 0: # If not HOLD
signal_type = None
if action == 1: # BUY/LONG
signal_type = 'buy'
elif action == 2: # SELL/SHORT
signal_type = 'sell'
elif action == 3: # CLOSE
if self.position == 'long':
signal_type = 'close_long'
elif self.position == 'short':
signal_type = 'close_short'
if signal_type:
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.current_price,
'type': signal_type,
'balance': self.balance,
'pnl': self.total_pnl
})
# Check for stop loss / take profit hits
self.check_sl_tp()
# Move to next step
self.current_step += 1
done = self.current_step >= len(self.data) - 1
# Get new state
next_state = self.get_state()
# Create info dictionary
info = {
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
'price': self.current_price,
'balance': self.balance,
'position': self.position,
'pnl': self.total_pnl
}
return next_state, reward, done, info
def check_sl_tp(self):
"""Check if stop loss or take profit has been hit"""
if self.position == 'flat':
return
if self.position == 'long':
# Check stop loss
if self.current_price <= self.stop_loss:
# Stop loss hit
pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.stop_loss,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'stop_loss'
})
# Update win/loss count
self.loss_count += 1
logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.stop_loss,
'type': 'stop_loss_long',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif self.current_price >= self.take_profit:
# Take profit hit
pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.take_profit,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'take_profit'
})
# Update win/loss count
self.win_count += 1
logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.take_profit,
'type': 'take_profit_long',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
elif self.position == 'short':
# Check stop loss
if self.current_price >= self.stop_loss:
# Stop loss hit
pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.stop_loss,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'stop_loss'
})
# Update win/loss count
self.loss_count += 1
logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.stop_loss,
'type': 'stop_loss_short',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif self.current_price <= self.take_profit:
# Take profit hit
pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.take_profit,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'take_profit'
})
# Update win/loss count
self.win_count += 1
logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.take_profit,
'type': 'take_profit_short',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
def get_state(self):
"""Create state representation for the agent with enhanced features"""
# Ensure we have enough data
if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
# Return zeros if not enough data
return np.zeros(STATE_SIZE)
# Create a normalized state vector with recent price action and indicators
state_components = []
# Safely get the latest price
try:
latest_price = self.features['price'][-1]
except IndexError:
# If we can't get the latest price, return zeros
return np.zeros(STATE_SIZE)
# Safely get price features
try:
# Price features (normalize recent prices by the latest price)
price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
state_components.append(price_features)
except (IndexError, ZeroDivisionError):
# If we can't get price features, use zeros
state_components.append(np.zeros(10))
# Safely get volume features
try:
# Volume features (normalize by max volume)
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
vol_features = np.array(self.features['volume'][-5:]) / max_vol
state_components.append(vol_features)
except (IndexError, ZeroDivisionError):
# If we can't get volume features, use zeros
state_components.append(np.zeros(5))
# Technical indicators
rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1
state_components.append(rsi)
# MACD (normalize)
macd_vals = np.array(self.features['macd'][-3:])
macd_signal = np.array(self.features['macd_signal'][-3:])
macd_hist = np.array(self.features['macd_hist'][-3:])
macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
macd_norm = macd_vals / macd_scale
macd_signal_norm = macd_signal / macd_scale
macd_hist_norm = macd_hist / macd_scale
state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm])
# Bollinger position (where is price relative to bands)
bb_upper = np.array(self.features['bollinger_upper'][-3:])
bb_lower = np.array(self.features['bollinger_lower'][-3:])
bb_mid = np.array(self.features['bollinger_mid'][-3:])
price = np.array(self.features['price'][-3:])
# Calculate position of price within Bollinger Bands (0 to 1)
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
state_components.append(np.array(bb_pos))
# Stochastic oscillator
state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0)
state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0)
# Add predicted prices (if available)
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
# Normalize predictions relative to current price
pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
state_components.append(pred_norm)
else:
# Add zeros if no predictions
state_components.append(np.zeros(3))
# Add extrema signals (if available)
if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0:
# Get recent signals
idx = len(self.optimal_signals) - 5
if idx < 0:
idx = 0
recent_signals = self.optimal_signals[idx:idx+5]
# Pad if needed
if len(recent_signals) < 5:
recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
state_components.append(recent_signals)
else:
# Add zeros if no signals
state_components.append(np.zeros(5))
# Position info
position_info = np.zeros(5)
if self.position == 'long':
position_info[0] = 1.0 # Position is long
position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL %
position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss %
position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit %
position_info[4] = self.position_size / self.balance # Position size relative to balance
elif self.position == 'short':
position_info[0] = -1.0 # Position is short
position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL %
position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss %
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
position_info[4] = self.position_size / self.balance # Position size relative to balance
state_components.append(position_info)
# Extended state features: momentum, volatility, market regime, trade history
# 1. Price momentum features (rate of change over different periods)
if len(self.features['price']) >= 20:
roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0
roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0
roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0
momentum_features = np.array([roc_5, roc_10, roc_20])
state_components.append(momentum_features)
else:
state_components.append(np.zeros(3))
# 2. Volatility features
if len(self.features['price']) >= 20:
# Calculate price returns
returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1]
# Calculate volatility (standard deviation of returns)
volatility = np.std(returns)
# Calculate normalized high-low range
high_low_range = np.mean([
(self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close']
for i in range(max(0, len(self.data)-5), len(self.data))
]) if len(self.data) > 0 else 0
# ATR normalized by price
atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0
volatility_features = np.array([volatility, high_low_range, atr_norm])
state_components.append(volatility_features)
else:
state_components.append(np.zeros(3))
# 3. Market regime features
if len(self.features['price']) >= 50:
# Trend strength (ADX-like measure)
ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
trend_strength = abs(ema9 - ema21) / ema21
# Detect if in range or trending
is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0
# Detect if near support/resistance
near_support = 1.0 if self.is_near_support() else 0.0
near_resistance = 1.0 if self.is_near_resistance() else 0.0
market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance])
state_components.append(market_regime)
else:
state_components.append(np.zeros(5))
# 4. Trade history features
if len(self.trades) > 0:
# Recent win/loss ratio
recent_trades = self.trades[-min(10, len(self.trades)):]
win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades)
# Average profit/loss
avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0
avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0
# Normalize by balance
avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0
avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0
# Last trade result
last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0
trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl])
state_components.append(trade_history)
else:
state_components.append(np.zeros(4))
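# Component sizes: 10 (price) + 5 (volume) + 3 (RSI) + 9 (MACD) + 3 (Bollinger)
# + 6 (stochastic) + 3 (predictions) + 5 (extrema signals) + 5 (position)
# + 3 (momentum) + 3 (volatility) + 5 (regime) + 4 (trade history) = 64 = STATE_SIZE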
# Combine all features
state = np.concatenate([comp.flatten() for comp in state_components])
# Replace any NaN or infinite values
state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure the state has the correct size
if len(state) != STATE_SIZE:
logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
# Pad or truncate to match expected size
if len(state) < STATE_SIZE:
state = np.pad(state, (0, STATE_SIZE - len(state)))
else:
state = state[:STATE_SIZE]
return state
def get_expanded_state_size(self):
"""Calculate the size of the expanded state representation"""
# Create a dummy state to get its size
state = self.get_state()
return len(state)
# NOTE: module-level helper, not a TradingEnvironment method; it takes the agent
# and environment explicitly
async def expand_model_with_new_features(agent, env):
"""Expand the model to handle new features without retraining from scratch"""
# Get the new state size
new_state_size = env.get_expanded_state_size()
# Only expand if the new state size is larger
if new_state_size > agent.state_size:
logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")
# Expand the model
success = agent.expand_model(
new_state_size=new_state_size,
new_hidden_size=512, # Increase hidden size for more capacity
new_lstm_layers=3, # More layers for deeper patterns
new_attention_heads=8 # More attention heads for complex relationships
)
if success:
logger.info(f"Model successfully expanded to handle {new_state_size} features")
return True
else:
logger.error("Failed to expand model")
return False
else:
logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
return True
def calculate_reward(self, action):
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
reward = 0
# Base reward for actions
if action == 0: # HOLD
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
elif action == 1: # BUY/LONG
if self.position == 'flat':
# Opening a long position
self.position = 'long'
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # fix: record entry index for duration tracking
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal buy point (bottom)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
reward += 3.0 # Increased bonus for buying at a bottom
# Check for volume spike (indicating potential big movement)
if len(self.features['volume']) > 5:
avg_volume = np.mean(self.features['volume'][-5:-1])
current_volume = self.features['volume'][-1]
if current_volume > avg_volume * 1.5:
reward += 2.0 # Bonus for entering during high volume
# Check for price action signals
if self.features['rsi'][-1] < 30: # Oversold condition
reward += 1.5 # Bonus for buying at oversold levels
# Check if we're buying in a clear uptrend (good)
if self.is_uptrend():
reward += 1.0 # Bonus for buying in uptrend
elif self.is_downtrend():
reward -= 0.25 # Reduced penalty for buying in downtrend
else:
reward += 0.2 # Small reward for opening a position
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif self.position == 'short':
# Close short and open long
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Record trade
trade_duration = len(self.features['price']) - self.entry_index
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': trade_duration,
'market_direction': self.get_market_direction()
})
# Reward based on PnL with stronger penalties for losses
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
# Stronger penalty for losses, scaled by the size of the loss
loss_penalty = 1.0 + abs(pnl_dollar) / 5
reward -= loss_penalty
self.loss_count += 1
# Extra penalty for closing a losing trade too quickly
if trade_duration < 5:
reward -= 0.5 # Penalty for very short losing trades
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Now open long
self.position = 'long'
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal buy point
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
reward += 2.0 # Bonus for buying at a bottom
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif action == 2: # SELL/SHORT
if self.position == 'flat':
# Opening a short position
self.position = 'short'
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # fix: record entry index for duration tracking
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal sell point (top)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
reward += 3.0 # Increased bonus for selling at a top
# Check for volume spike
if len(self.features['volume']) > 5:
avg_volume = np.mean(self.features['volume'][-5:-1])
current_volume = self.features['volume'][-1]
if current_volume > avg_volume * 1.5:
reward += 2.0 # Bonus for entering during high volume
# Check for price action signals
if self.features['rsi'][-1] > 70: # Overbought condition
reward += 1.5 # Bonus for selling at overbought levels
# Check if we're selling in a clear downtrend (good)
if self.is_downtrend():
reward += 1.0 # Bonus for selling in downtrend
elif self.is_uptrend():
reward -= 0.25 # Reduced penalty for selling in uptrend
else:
reward += 0.2 # Small reward for opening a position
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif self.position == 'long':
# Close long and open short
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Now open short
self.position = 'short'
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # fix: record entry index (mirrors the long-flip path)
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal sell point
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
reward += 2.0 # Bonus for selling at a top
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif action == 3: # CLOSE
if self.position == 'long':
# Close long position
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
elif self.position == 'short':
# Close short position
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Add prediction accuracy component to reward
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
# Compare the first prediction with actual price
if len(self.data) > 1:
actual_price = self.data[-1]['close']
predicted_price = self.predicted_prices[0]
prediction_error = abs(predicted_price - actual_price) / actual_price
# Reward accurate predictions, penalize bad ones
if prediction_error < 0.005: # Less than 0.5% error
reward += 0.5
elif prediction_error > 0.02: # More than 2% error
reward -= 0.5
return reward
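# Reward shaping summary: HOLD costs -0.05; entries at detected bottoms/tops earn
# up to +3.0; volume spikes +2.0; RSI extremes +1.5; trend alignment +1.0 (or -0.25
# against trend); realized PnL adds roughly (1 + pnl/10) on wins and at least -1 on
# losses, plus a small accuracy bonus/penalty for the price predictor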
def is_downtrend(self):
"""Check if the market is in a downtrend"""
if len(self.features['price']) < 20:
return False
# Use EMA to determine trend
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
# Downtrend if short EMA is below long EMA
return short_ema < long_ema
def is_uptrend(self):
"""Check if the market is in an uptrend"""
if len(self.features['price']) < 20:
return False
# Use EMA to determine trend
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
# Uptrend if short EMA is above long EMA
return short_ema > long_ema
def get_market_direction(self):
"""Get the current market direction"""
if self.is_uptrend():
return "uptrend"
elif self.is_downtrend():
return "downtrend"
else:
return "sideways"
def analyze_trades(self):
"""Analyze completed trades to identify patterns"""
if not self.trades:
return {}
analysis = {
'total_trades': len(self.trades),
'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0),
'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0),
'avg_win': 0,
'avg_loss': 0,
'avg_duration': 0,
'uptrend_win_rate': 0,
'downtrend_win_rate': 0,
'sideways_win_rate': 0
}
# Calculate averages
wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0]
losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0]
durations = [t.get('duration', 0) for t in self.trades]
analysis['avg_win'] = sum(wins) / len(wins) if wins else 0
analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0
analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0
# Calculate win rates by market direction
for direction in ['uptrend', 'downtrend', 'sideways']:
direction_trades = [t for t in self.trades if t.get('market_direction') == direction]
if direction_trades:
wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0)
analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100
return analysis
def initialize_price_predictor(self, device="cpu"):
"""Initialize the price prediction model"""
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
self.price_predictor.to(device)
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
self.predicted_prices = np.array([])
def train_price_predictor(self):
"""Train the price prediction model on recent data"""
if len(self.features['price']) < 35:
return 0.0
# Get price history
price_history = self.features['price']
# Train the model
loss = self.price_predictor.train_on_new_data(
price_history,
self.price_predictor_optimizer,
epochs=5
)
return loss
def update_price_predictions(self):
"""Update price predictions"""
if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
self.predicted_prices = np.array([])
return
# Get price history
price_history = self.features['price']
try:
# Get predictions
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
except Exception as e:
logger.warning(f"Error updating predictions: {e}")
self.predicted_prices = np.array([])
def identify_optimal_trades(self):
"""Identify optimal entry and exit points based on local extrema"""
if len(self.features['price']) < 20:
return
# Find local bottoms and tops
bottoms, tops = find_local_extrema(self.features['price'], window=5)
# Store optimal trade points
self.optimal_bottoms = bottoms # Buy points
self.optimal_tops = tops # Sell points
# Create optimal trade signals
self.optimal_signals = np.zeros(len(self.features['price']))
for i in bottoms:
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
self.optimal_signals[i] = 1 # Buy signal
for i in tops:
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
self.optimal_signals[i] = -1 # Sell signal
logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
def calculate_position_size(self):
"""Calculate position size based on current balance and risk parameters"""
# Use a fixed percentage of balance for each trade
risk_percent = 5.0 # Risk 5% of balance per trade
# Calculate position size with leverage
position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE
# Apply a safety factor to avoid liquidation
safety_factor = 0.8
position_size *= safety_factor
# Ensure minimum position size
min_position = 10.0 # Minimum position size in USD
position_size = max(position_size, min(min_position, self.balance * 0.5))
# Ensure position size doesn't exceed balance * leverage
max_position = self.balance * MAX_LEVERAGE
position_size = min(position_size, max_position)
return position_size
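# Example: balance=100 gives 100 * 0.05 * 100 (leverage) = 500, * 0.8 safety = 400 USD
# notional, capped at balance * MAX_LEVERAGE = 10,000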
def calculate_fees(self, position_size):
"""Calculate trading fees for a given position size"""
# Typical fee rate for crypto exchanges (0.1%)
fee_rate = 0.001
# Calculate fee
fee = position_size * fee_rate
return fee
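# Example: a 400 USD position incurs 400 * 0.001 = 0.40 USD in fees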
def is_uncertain_market(self):
"""Check if the market is in an uncertain/sideways state"""
if len(self.features['price']) < 20:
return True
# Check if price is within a narrow range
recent_prices = self.features['price'][-20:]
price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices)
# Check if EMAs are close to each other
if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
ema_diff = abs(short_ema - long_ema) / long_ema
# Return True if price range is small and EMAs are close
return price_range < 0.02 and ema_diff < 0.005
return price_range < 0.015 # Very narrow range
def is_near_support(self):
"""Check if current price is near a support level"""
if not hasattr(self, 'features') or len(self.features['price']) < 30:
return False
# Find recent lows
prices = self.features['price'][-30:]
lows = []
for i in range(1, len(prices)-1):
if prices[i] < prices[i-1] and prices[i] < prices[i+1]:
lows.append(prices[i])
if not lows:
return False
# Check if current price is near any of these lows
current_price = self.current_price
for low in lows:
if abs(current_price - low) / low < 0.01: # Within 1% of a recent low
return True
return False
def is_near_resistance(self):
"""Check if current price is near a resistance level"""
if not hasattr(self, 'features') or len(self.features['price']) < 30:
return False
# Find recent highs
prices = self.features['price'][-30:]
highs = []
for i in range(1, len(prices)-1):
if prices[i] > prices[i-1] and prices[i] > prices[i+1]:
highs.append(prices[i])
if not highs:
return False
# Check if current price is near any of these highs
current_price = self.current_price
for high in highs:
if abs(current_price - high) / high < 0.01: # Within 1% of a recent high
return True
return False
def is_market_turning(self):
"""Check if the market is potentially changing direction"""
if len(self.features['price']) < 20:
return False
# Check for divergence between price and momentum indicators
if len(self.features['rsi']) > 5:
# Price making higher highs but RSI making lower highs (bearish divergence)
price_trend = self.features['price'][-1] > self.features['price'][-5]
rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5]
if price_trend != rsi_trend:
return True
# Check for EMA crossover
if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1:
short_ema_prev = self.features['ema_9'][-2]
long_ema_prev = self.features['ema_21'][-2]
short_ema_curr = self.features['ema_9'][-1]
long_ema_curr = self.features['ema_21'][-1]
# Check if EMAs just crossed
if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \
(short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr):
return True
return False
def is_market_against_position(self, position_type):
"""Check if market conditions have turned against the current position"""
if position_type == 'long':
# For long positions, check if market has turned bearish
return self.is_downtrend() and not self.is_near_support()
elif position_type == 'short':
# For short positions, check if market has turned bullish
return self.is_uptrend() and not self.is_near_resistance()
return False
def is_near_optimal_exit(self, position_type):
"""Check if current price is near an optimal exit point for the position"""
current_idx = len(self.features['price']) - 1
if position_type == 'long' and hasattr(self, 'optimal_tops'):
# For long positions, optimal exit is near tops
for top_idx in self.optimal_tops:
if abs(current_idx - top_idx) < 3: # Within 3 candles of a top
return True
elif position_type == 'short' and hasattr(self, 'optimal_bottoms'):
# For short positions, optimal exit is near bottoms
for bottom_idx in self.optimal_bottoms:
if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom
return True
return False
def calculate_future_profit_potential(self, position_type, lookahead=20):
"""
Calculate potential profit if position is held for a certain period
This is used for retrospective backtesting rewards
Args:
position_type: 'long' or 'short'
lookahead: Number of candles to look ahead
Returns:
Potential profit percentage
"""
if len(self.data) <= 1 or self.current_step >= len(self.data):
return 0
# Get current price
current_price = self.current_price
# Get future prices (if available in historical data)
future_prices = []
current_idx = self.current_step
# Safely get future prices
for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
if current_idx + i < len(self.data):
future_prices.append(self.data[current_idx + i]['close'])
if not future_prices:
return 0
# Calculate potential profit
if position_type == 'long':
# For long positions, find the maximum price in the future
max_future_price = max(future_prices)
potential_profit = (max_future_price - current_price) / current_price * 100
else: # short
# For short positions, find the minimum price in the future
min_future_price = min(future_prices)
potential_profit = (current_price - min_future_price) / current_price * 100
return potential_profit
async def initialize_futures(self, exchange):
"""Initialize futures trading parameters"""
if not self.demo:
try:
# Set up futures trading parameters
await exchange.set_position_mode(True) # Hedge mode
await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
logger.info(f"Futures initialized with {self.leverage}x leverage")
except Exception as e:
logger.error(f"Failed to initialize futures trading: {str(e)}")
logger.info("Falling back to demo mode for safety")
self.demo = True  # fix: assign the instance attribute (a bare `demo` would be a dead local)
async def execute_real_trade(self, exchange, action, current_price):
"""Execute real futures trade on MEXC"""
try:
position_size = self.calculate_position_size()
if action == 1: # Open long
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='buy',
amount=position_size,
params={'positionSide': 'LONG'}
)
logger.info(f"Opened LONG position: {order}")
elif action == 2: # Open short
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='sell',
amount=position_size,
params={'positionSide': 'SHORT'}
)
logger.info(f"Opened SHORT position: {order}")
elif action == 3: # Close position
position_side = 'LONG' if self.position == 'long' else 'SHORT'
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='sell' if position_side == 'LONG' else 'buy',
amount=self.position_size,
params={'positionSide': position_side}
)
logger.info(f"Closed {position_side} position: {order}")
return order
except Exception as e:
logger.error(f"Trade execution failed: {e}")
return None
# Ensure GPU usage if available
def get_device():
"""Get the best available device (CUDA GPU or CPU)"""
if torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Set up for mixed precision training
torch.backends.cudnn.benchmark = True
else:
device = torch.device("cpu")
logger.info("GPU not available, using CPU")
return device
# Agent with GPU support and mixed-precision training
class Agent:
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
"""Initialize Agent with architecture parameters stored as attributes"""
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size # Store hidden_size as an instance attribute
self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute
self.attention_heads = attention_heads # Store attention_heads as an instance attribute
# Set device
self.device = device if device is not None else get_device()
# Initialize networks
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
# Initialize optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
# Initialize replay memory
self.memory = ReplayMemory(MEMORY_SIZE)
# Initialize exploration parameters
self.epsilon = EPSILON_START
self.epsilon_decay = EPSILON_DECAY
self.epsilon_min = EPSILON_END
# Initialize step counter
self.steps_done = 0
# Initialize TensorBoard writer
self.writer = None
# Initialize GradScaler for mixed precision training
self.scaler = torch.amp.GradScaler('cuda') if self.device.type == "cuda" else None  # torch.amp API, matching the autocast call in learn()
# Rest of the initialization code...
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
"""Expand the model to handle more features or increase capacity"""
logger.info(f"Expanding model: {self.state_size}{new_state_size}, "
f"hidden: {self.policy_net.hidden_size}{new_hidden_size}")
# Save old weights
old_state_dict = self.policy_net.state_dict()
# Create new larger networks
new_policy_net = DQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
new_target_net = DQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
# Transfer weights for common layers
new_state_dict = new_policy_net.state_dict()
for name, param in old_state_dict.items():
if name in new_state_dict:
# If shapes match, copy directly
if new_state_dict[name].shape == param.shape:
new_state_dict[name] = param
# For first layer, copy weights for the original input dimensions
elif name == "fc1.weight":
new_state_dict[name][:, :self.state_size] = param
# For other layers, initialize with a strategy that preserves scale
else:
logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}")
# Load transferred weights
new_policy_net.load_state_dict(new_state_dict)
new_target_net.load_state_dict(new_state_dict)
# Replace networks
self.policy_net = new_policy_net
self.target_net = new_target_net
self.target_net.eval()
# Update optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
# Update state size
self.state_size = new_state_size
# Print new model size
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"New model size: {total_params:,} parameters")
return True
def select_action(self, state, training=True):
sample = random.random()
if training:
# More aggressive epsilon decay for faster exploitation
self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
np.exp(-1.5 * self.steps_done / EPSILON_DECAY) # Increased decay factor
self.steps_done += 1
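# With the defaults this gives epsilon of roughly 0.26 after 10k steps and 0.06 after 30k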
# Lower threshold for exploration, especially in live trading
if not training:
# In live trading, be much more aggressive with exploitation
self.epsilon = max(EPSILON_END, self.epsilon * 0.95)
if sample > self.epsilon or not training:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
action_values = self.policy_net(state_tensor)
# Take the greedy action; in live trading this is always the highest-value action
return action_values.max(1)[1].item()
else:
return random.randrange(self.action_size)
def learn(self):
"""Learn from a batch of experiences"""
if len(self.memory) < BATCH_SIZE:
return None
try:
# Sample a batch of experiences
experiences = self.memory.sample(BATCH_SIZE)
# Convert experiences to tensors
states = torch.FloatTensor(np.array([e.state for e in experiences])).to(self.device)
actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
next_states = torch.FloatTensor(np.array([e.next_state for e in experiences])).to(self.device)
dones = torch.FloatTensor([e.done for e in experiences]).to(self.device)
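# Both branches below compute the same TD(0) target:
# y = r + GAMMA * max_a' Q_target(s', a') * (1 - done)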
# Use mixed precision for forward/backward passes
if self.device.type == "cuda" and self.scaler is not None:
with torch.amp.autocast('cuda'):
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values with target network
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
# Reshape target values to match current_q_values
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass with mixed precision
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
# Gradient clipping to prevent exploding gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# Standard precision for CPU
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values with target network
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
# Reshape target values to match current_q_values
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
# Update steps done
self.steps_done += 1
# Update target network
if self.steps_done % TARGET_UPDATE == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
return loss.item()
except Exception as e:
logger.error(f"Error during learning: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
return None
def update_target_network(self):
self.target_net.load_state_dict(self.policy_net.state_dict())
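        # A soft (Polyak) update is a common alternative to this hard copy;
        # a hypothetical sketch (tau is illustrative and not defined in this file):
        #   tau = 0.005
        #   for tp, pp in zip(self.target_net.parameters(), self.policy_net.parameters()):
        #       tp.data.copy_(tau * pp.data + (1 - tau) * tp.data)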
def save(self, path="models/trading_agent_best_pnl.pt"):
"""Save the model in a format compatible with PyTorch 2.6+"""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(path), exist_ok=True)
# Ensure architecture parameters are set
if not hasattr(self, 'hidden_size'):
self.hidden_size = 256 # Default value
logger.warning("Setting default hidden_size=256 for saving")
if not hasattr(self, 'lstm_layers'):
self.lstm_layers = 2 # Default value
logger.warning("Setting default lstm_layers=2 for saving")
if not hasattr(self, 'attention_heads'):
self.attention_heads = 4 # Default value
logger.warning("Setting default attention_heads=4 for saving")
# Save model state
checkpoint = {
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'state_size': self.state_size,
'action_size': self.action_size,
'hidden_size': self.hidden_size,
'lstm_layers': self.lstm_layers,
'attention_heads': self.attention_heads
}
# Save scaler state if it exists
if hasattr(self, 'scaler') and self.scaler is not None:
checkpoint['scaler'] = self.scaler.state_dict()
# Save with pickle_protocol=4 for better compatibility
torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
logger.info(f"Model saved to {path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
import traceback
logger.error(traceback.format_exc())
def load(self, path="models/trading_agent_best_pnl.pt"):
"""Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
try:
# First try to load with weights_only=False (for models saved with older PyTorch versions)
try:
logger.info(f"Attempting to load model with weights_only=False: {path}")
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
logger.info("Model loaded successfully with weights_only=False")
except Exception as e1:
logger.warning(f"Failed to load with weights_only=False: {e1}")
# Try with safe_globals context manager
try:
logger.info("Attempting to load with safe_globals context manager")
import numpy as np
from torch.serialization import safe_globals
# Add numpy scalar to safe globals
with safe_globals(['numpy._core.multiarray.scalar']):
checkpoint = torch.load(path, map_location=self.device)
logger.info("Model loaded successfully with safe_globals")
except Exception as e2:
logger.warning(f"Failed to load with safe_globals: {e2}")
# Last resort: try with pickle_module=pickle
logger.info("Attempting to load with pickle_module")
import pickle
checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
logger.info("Model loaded successfully with pickle_module")
# Load state dictionaries
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
# Try to load optimizer state
try:
self.optimizer.load_state_dict(checkpoint['optimizer'])
except Exception as e:
logger.warning(f"Could not load optimizer state: {e}")
# Load epsilon if available
if 'epsilon' in checkpoint:
self.epsilon = checkpoint['epsilon']
# Load architecture parameters if available
if 'state_size' in checkpoint:
self.state_size = checkpoint['state_size']
if 'action_size' in checkpoint:
self.action_size = checkpoint['action_size']
if 'hidden_size' in checkpoint:
self.hidden_size = checkpoint['hidden_size']
else:
# If hidden_size not in checkpoint, infer from model
try:
self.hidden_size = self.policy_net.fc1.weight.shape[0]
logger.info(f"Inferred hidden_size={self.hidden_size} from model")
                except Exception:
                    self.hidden_size = 256  # Default value
                    logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
if 'lstm_layers' in checkpoint:
self.lstm_layers = checkpoint['lstm_layers']
else:
self.lstm_layers = 2 # Default value
if 'attention_heads' in checkpoint:
self.attention_heads = checkpoint['attention_heads']
else:
self.attention_heads = 4 # Default value
logger.info(f"Model loaded successfully from {path}")
except Exception as e:
logger.error(f"Error loading model: {e}")
import traceback
logger.error(traceback.format_exc())
raise
def add_chart_to_tensorboard(self, env, global_step):
"""Add trading chart to TensorBoard"""
try:
if len(env.data) < 10:
return
# Create chart image
chart_img = create_candlestick_figure(
env.data,
env.trade_signals,
window_size=100,
title=f"Trading Chart - Step {global_step}"
)
if chart_img is not None:
# Convert PIL image to numpy array for TensorBoard
chart_array = np.array(chart_img)
# TensorBoard expects [C, H, W] format
chart_array = np.transpose(chart_array, (2, 0, 1))
self.writer.add_image('Trading Chart', chart_array, global_step)
# Add position information as text
entry_price = env.entry_price if env.entry_price else 0.00
position_info = f"""
**Current Position**: {env.position.upper()}
**Entry Price**: ${entry_price:.2f}
**Current Price**: ${env.data[-1]['close']:.2f}
**Position Size**: ${env.position_size:.2f}
**Unrealized PnL**: ${env.total_pnl:.2f}
"""
self.writer.add_text('Position', position_info, global_step)
except Exception as e:
logger.error(f"Error adding chart to TensorBoard: {str(e)}")
# Continue without visualization rather than crashing
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
"""Get live price data using websockets"""
# Connect to MEXC websocket
uri = "wss://stream.mexc.com/ws"
async with websockets.connect(uri) as websocket:
# Subscribe to kline data
subscribe_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
}
await websocket.send(json.dumps(subscribe_msg))
logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")
while True:
try:
response = await websocket.recv()
data = json.loads(response)
if 'data' in data:
kline = data['data']
candle = {
'timestamp': kline['t'],
'open': float(kline['o']),
'high': float(kline['h']),
'low': float(kline['l']),
'close': float(kline['c']),
'volume': float(kline['v'])
}
yield candle
except Exception as e:
logger.error(f"Websocket error: {e}")
                # Back off, then exit the connection context; the caller is
                # responsible for restarting this generator to reconnect
                await asyncio.sleep(5)
                break
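# Usage sketch (illustrative, not called anywhere in this file):
#   async for candle in get_live_prices("ETH/USDT", "1m"):
#       logger.info(f"Live close: {candle['close']}")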
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
"""Train the agent using historical and live data with GPU acceleration"""
# Initialize statistics tracking
stats = {
'episode_rewards': [],
'episode_lengths': [],
'balances': [],
'win_rates': [],
'episode_pnls': [],
'cumulative_pnl': [],
'drawdowns': [],
'prediction_accuracy': [],
'trade_analysis': []
}
# Track best models
best_reward = float('-inf')
best_pnl = float('-inf')
# Initialize TensorBoard writer if not already initialized
if not hasattr(agent, 'writer') or agent.writer is None:
agent.writer = SummaryWriter('runs/training')
# Training loop
for episode in range(num_episodes):
try:
# Reset environment
state = env.reset()
episode_reward = 0
prediction_loss = 0
# Episode loop
for step in range(max_steps_per_episode):
# Select action
action = agent.select_action(state)
# Take action
try:
next_state, reward, done, info = env.step(action)
except Exception as e:
logger.error(f"Error in step function: {e}")
break
# Store transition in replay memory
agent.memory.push(state, action, reward, next_state, done)
# Move to the next state
state = next_state
# Update episode reward
episode_reward += reward
# Learn from experience
if len(agent.memory) > BATCH_SIZE:
agent.learn()
# Update price predictions periodically
if step % 50 == 0:
try:
env.update_price_predictions()
env.identify_optimal_trades()
except Exception as e:
logger.warning(f"Error updating predictions: {e}")
# Add chart to TensorBoard periodically
if step % 50 == 0 or (step == max_steps_per_episode - 1) or done:
try:
global_step = episode * max_steps_per_episode + step
agent.add_chart_to_tensorboard(env, global_step)
except Exception as e:
logger.warning(f"Error adding chart to TensorBoard: {e}")
# End episode if done
if done:
break
# Update target network periodically
if episode % TARGET_UPDATE == 0:
agent.update_target_network()
# Calculate win rate
total_trades = env.win_count + env.loss_count
win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
# Train price predictor
try:
if episode % 5 == 0 and len(env.data) > 50:
prediction_loss = env.train_price_predictor()
except Exception as e:
logger.warning(f"Error training price predictor: {e}")
prediction_loss = 0
# Analyze trades
try:
trade_analysis = env.analyze_trades()
stats['trade_analysis'].append(trade_analysis)
except Exception as e:
logger.warning(f"Error analyzing trades: {e}")
trade_analysis = {}
stats['trade_analysis'].append({})
# Calculate prediction accuracy
prediction_accuracy = 0.0
try:
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
if len(env.data) > 5:
actual_prices = [candle['close'] for candle in env.data[-5:]]
predicted = env.predicted_prices[:min(5, len(actual_prices))]
errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
except Exception as e:
logger.warning(f"Error calculating prediction accuracy: {e}")
# Log statistics
stats['episode_rewards'].append(episode_reward)
stats['episode_lengths'].append(step + 1)
stats['balances'].append(env.balance)
stats['win_rates'].append(win_rate)
stats['episode_pnls'].append(env.episode_pnl)
stats['cumulative_pnl'].append(env.total_pnl)
stats['drawdowns'].append(env.max_drawdown * 100)
stats['prediction_accuracy'].append(prediction_accuracy)
# Log detailed trade analysis
if trade_analysis:
logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")
# Log to TensorBoard
agent.writer.add_scalar('Reward/train', episode_reward, episode)
agent.writer.add_scalar('Balance/train', env.balance, episode)
agent.writer.add_scalar('WinRate/train', win_rate, episode)
agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)
agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
# Add final chart for this episode
try:
agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)
except Exception as e:
logger.warning(f"Error adding final chart: {e}")
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")
# Save best model by reward
if episode_reward > best_reward:
best_reward = episode_reward
agent.save("models/trading_agent_best_reward.pt")
# Save best model by PnL
if env.episode_pnl > best_pnl:
best_pnl = env.episode_pnl
agent.save("models/trading_agent_best_pnl.pt")
# Save checkpoint
if episode % 10 == 0:
agent.save(f"models/trading_agent_episode_{episode}.pt")
except Exception as e:
logger.error(f"Error in episode {episode}: {e}")
continue
# Save final model
agent.save("models/trading_agent_final.pt")
# Plot training results
plot_training_results(stats)
return stats
def plot_training_results(stats):
"""Plot detailed training results"""
plt.figure(figsize=(20, 15))
# Plot rewards
plt.subplot(3, 2, 1)
plt.plot(stats['episode_rewards'])
plt.title('Episode Rewards')
plt.xlabel('Episode')
plt.ylabel('Reward')
# Plot balance
plt.subplot(3, 2, 2)
plt.plot(stats['balances'])
plt.title('Account Balance')
plt.xlabel('Episode')
plt.ylabel('Balance ($)')
# Plot win rate
plt.subplot(3, 2, 3)
plt.plot(stats['win_rates'])
plt.title('Win Rate')
plt.xlabel('Episode')
plt.ylabel('Win Rate (%)')
# Plot episode PnL
plt.subplot(3, 2, 4)
plt.plot(stats['episode_pnls'])
plt.title('Episode PnL')
plt.xlabel('Episode')
plt.ylabel('PnL ($)')
# Plot cumulative PnL
plt.subplot(3, 2, 5)
plt.plot(stats['cumulative_pnl'])
plt.title('Cumulative PnL')
plt.xlabel('Episode')
plt.ylabel('Cumulative PnL ($)')
# Plot drawdown
plt.subplot(3, 2, 6)
plt.plot(stats['drawdowns'])
plt.title('Maximum Drawdown')
plt.xlabel('Episode')
plt.ylabel('Drawdown (%)')
plt.tight_layout()
plt.savefig('training_results.png')
    # Save statistics to CSV; pad unequal list lengths, which can occur when
    # an episode aborts partway through stat collection
    df = pd.DataFrame({k: pd.Series(v) for k, v in stats.items()})
    df.to_csv('training_stats.csv', index=False)
logger.info("Training statistics saved to training_stats.csv and training_results.png")
def evaluate_agent(agent, env, num_episodes=10):
"""Evaluate the agent on test data"""
total_reward = 0
total_profit = 0
total_trades = 0
winning_trades = 0
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
initial_balance = env.balance
done = False
while not done:
# Select action (no exploration)
action = agent.select_action(state, training=False)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward += reward
total_reward += episode_reward
total_profit += env.balance - initial_balance
# Count trades and wins
for trade in env.trades:
if 'pnl_percent' in trade:
total_trades += 1
if trade['pnl_percent'] > 0:
winning_trades += 1
# Calculate averages
avg_reward = total_reward / num_episodes
avg_profit = total_profit / num_episodes
win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0
logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
f"Win Rate={win_rate:.1f}%")
return avg_reward, avg_profit, win_rate
async def test_training():
"""Test the training process with a small number of episodes"""
logger.info("Starting training tests...")
# Initialize exchange
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True,
})
try:
# Create environment with small initial balance for testing
env = TradingEnvironment(
exchange=exchange,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=100, # Small balance for testing
demo=True # Always use demo mode for testing
)
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
# Create agent
agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)
# Run a few test episodes
test_episodes = 3
logger.info(f"Running {test_episodes} test episodes...")
for episode in range(test_episodes):
state = env.reset()
episode_reward = 0
done = False
step = 0
while not done and step < 100: # Limit steps for testing
# Select action
action = agent.select_action(state)
# Take action
next_state, reward, done, info = env.step(action)
# Store experience
agent.memory.push(state, action, reward, next_state, done)
# Learn
loss = agent.learn()
state = next_state
episode_reward += reward
step += 1
# Print progress
if step % 10 == 0:
logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")
logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")
# Test model saving
try:
agent.save("models/test_model.pt")
logger.info("Successfully saved model")
except Exception as e:
logger.error(f"Error saving model: {e}")
logger.info("Training tests completed successfully")
return True
except Exception as e:
logger.error(f"Training test failed: {e}")
return False
    finally:
        # Plain (synchronous) ccxt exchanges may not define close(), and when
        # they do it is not awaitable; guard both cases
        if hasattr(exchange, 'close'):
            close_result = exchange.close()
            if asyncio.iscoroutine(close_result):
                await close_result
async def initialize_exchange():
"""Initialize the exchange connection"""
try:
# Try to initialize with async support first
try:
exchange = ccxt.pro.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with async support: {exchange.id}")
except (AttributeError, ImportError):
# Fall back to standard CCXT
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
return exchange
except Exception as e:
logger.error(f"Failed to initialize exchange: {e}")
raise
async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""Fetch historical OHLCV data from the exchange"""
try:
logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")
# Use the refactored fetch method
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
if not data:
logger.warning("No historical data received")
return data
except Exception as e:
logger.error(f"Failed to fetch historical data: {e}")
return []
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
"""Run the trading bot in live mode with enhanced error handling"""
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
# Verify agent is properly initialized
try:
# Ensure agent has all required attributes
if not hasattr(agent, 'hidden_size'):
agent.hidden_size = 256 # Default value
logger.warning("Agent missing hidden_size attribute, using default: 256")
if not hasattr(agent, 'lstm_layers'):
agent.lstm_layers = 2 # Default value
logger.warning("Agent missing lstm_layers attribute, using default: 2")
if not hasattr(agent, 'attention_heads'):
agent.attention_heads = 4 # Default value
logger.warning("Agent missing attention_heads attribute, using default: 4")
logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
except Exception as e:
logger.error(f"Error checking agent configuration: {e}")
# Continue anyway, as these are just informational attributes
if not demo:
# Confirm with user before starting live trading
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
if confirmation != "CONFIRM":
logger.info("Live trading canceled by user")
return
# Initialize futures trading if not in demo mode
try:
await env.initialize_futures(exchange)
logger.info(f"Futures trading initialized with {leverage}x leverage")
except Exception as e:
logger.error(f"Failed to initialize futures trading: {str(e)}")
logger.info("Falling back to demo mode for safety")
demo = True
# Initialize TensorBoard for monitoring
if not hasattr(agent, 'writer') or agent.writer is None:
from torch.utils.tensorboard import SummaryWriter
        # `datetime` here is the module (the later `import datetime` shadows
        # the earlier from-import), so qualify as datetime.datetime
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')
# Track performance metrics
trades_count = 0
winning_trades = 0
total_profit = 0
max_drawdown = 0
peak_balance = env.balance
step_counter = 0
prev_position = 'flat'
# Create directory for trade logs
os.makedirs('trade_logs', exist_ok=True)
    # `datetime` here is the module, so qualify as datetime.datetime
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trade_log_path = f'trade_logs/trades_{current_time}.csv'
with open(trade_log_path, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
logger.info("Entering live trading loop...")
try:
while True:
try:
# Fetch latest candle data
candle = await get_latest_candle(exchange, symbol)
if candle is None:
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
await asyncio.sleep(5)
continue
# Add new data to environment
env.add_data(candle)
# Get current state and select action
state = env.get_state()
# Verify state shape matches agent's expected input
if state.shape[0] != agent.state_size:
logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
# Pad or truncate state to match expected size
if state.shape[0] < agent.state_size:
state = np.pad(state, (0, agent.state_size - state.shape[0]))
else:
state = state[:agent.state_size]
action = agent.select_action(state, training=False)
# Ensure action is valid
if action >= agent.action_size:
logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
action = agent.action_size - 1
# Log action
action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
# Execute action
if not demo:
# Execute real trade on exchange
current_price = env.data[-1]['close']
trade_result = await env.execute_real_trade(exchange, action, current_price)
if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
logger.error(f"Trade execution failed: {error_msg}")
# Continue with simulated trade for tracking purposes
                # Update environment with action (simulated in demo mode).
                # env.step may return 3 or 4 values depending on the
                # implementation; capture the result once and unpack by length,
                # since calling step twice would execute the action twice
                step_result = env.step(action)
                if len(step_result) == 4:
                    next_state, reward, done, info = step_result
                else:
                    logger.warning("Step function returned 3 values instead of 4, creating info dict")
                    next_state, reward, done = step_result
                    info = {
                        'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
                        'price': env.current_price,
                        'balance': env.balance,
                        'position': env.position,
                        'pnl': env.total_pnl
                    }
# Log trade if position changed
if env.position != prev_position:
trades_count += 1
if env.last_trade_profit > 0:
winning_trades += 1
total_profit += env.last_trade_profit
# Log trade details
with open(trade_log_path, 'a') as f:
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
# Update performance metrics
if env.balance > peak_balance:
peak_balance = env.balance
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
if current_drawdown > max_drawdown:
max_drawdown = current_drawdown
# Update TensorBoard metrics
step_counter += 1
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
# Update chart visualization
if step_counter % 5 == 0 or env.position != prev_position:
agent.add_chart_to_tensorboard(env, step_counter)
# Log performance summary
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
performance_text = f"""
**Live Trading Performance**
Balance: ${env.balance:.2f}
Total PnL: ${env.total_pnl:.2f}
Trades: {trades_count}
Win Rate: {win_rate:.1f}%
Max Drawdown: {max_drawdown*100:.1f}%
"""
agent.writer.add_text('Performance', performance_text, step_counter)
prev_position = env.position
# Wait for next candle
logger.info(f"Waiting for next candle... (Step {step_counter})")
await asyncio.sleep(10) # Check every 10 seconds
except Exception as e:
logger.error(f"Error in live trading loop: {str(e)}")
import traceback
logger.error(traceback.format_exc())
logger.info("Continuing after error...")
await asyncio.sleep(30) # Wait longer after an error
except KeyboardInterrupt:
logger.info("Live trading stopped by user")
# Final performance report
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
logger.info(f"Trading session summary:")
logger.info(f"Total trades: {trades_count}")
logger.info(f"Win rate: {win_rate:.1f}%")
logger.info(f"Final balance: ${env.balance:.2f}")
logger.info(f"Total profit: ${total_profit:.2f}")
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
logger.info(f"Trade log saved to: {trade_log_path}")
async def get_latest_candle(exchange, symbol):
"""Get the latest candle data"""
try:
# Use the refactored fetch method with limit=1
data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)
if data and len(data) > 0:
return data[0]
else:
logger.warning("No candle data received")
return None
except Exception as e:
logger.error(f"Failed to fetch latest candle: {e}")
return None
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
"""Fetch OHLCV data with proper handling for both async and standard CCXT"""
try:
# Check if exchange has fetchOHLCV method
if not hasattr(exchange, 'fetchOHLCV'):
logger.error("Exchange does not support OHLCV data fetching")
return []
        # ccxt.pro exchanges expose fetch_ohlcv as a coroutine; plain ccxt is
        # synchronous, so run it in a thread executor instead of blocking the
        # event loop
        if asyncio.iscoroutinefunction(getattr(exchange, 'fetch_ohlcv', None)):
            ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        else:
            loop = asyncio.get_event_loop()
            ohlcv = await loop.run_in_executor(
                None,
                lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
            )
# Convert to list of dictionaries
data = []
for candle in ohlcv:
timestamp, open_price, high, low, close, volume = candle
data.append({
'timestamp': timestamp,
'open': open_price,
'high': high,
'low': low,
'close': close,
'volume': volume
})
logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
return data
except Exception as e:
logger.error(f"Failed to fetch OHLCV data: {e}")
return []
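# Illustrative convenience (an assumption of this sketch: nothing else in this
# file calls it): convert fetch_ohlcv_data's candle dicts into a DataFrame.
def candles_to_dataframe(candles):
    """Convert a list of candle dicts into a time-indexed pandas DataFrame."""
    df = pd.DataFrame(candles)
    if not df.empty:
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        df = df.set_index('timestamp').sort_index()
    return df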
async def initialize_websocket_data_stream(symbol="ETH/USDT", timeframe="1m"):
"""Initialize a WebSocket connection for real-time trading data
Args:
symbol: Trading pair symbol (e.g., "ETH/USDT")
timeframe: Timeframe for candle aggregation (e.g., "1m")
Returns:
Tuple of (websocket, candle_data) where websocket is the BinanceWebSocket instance
and candle_data is a dict to track ongoing candle formation
"""
try:
# Initialize historical data handler to get initial data
historical_data = BinanceHistoricalData()
        # Convert timeframe to seconds for historical data (default to 1m)
        timeframe_seconds = {"1m": 60, "5m": 300, "15m": 900, "1h": 3600}
        interval_seconds = timeframe_seconds.get(timeframe, 60)
# Fetch initial historical data
initial_data = historical_data.get_historical_candles(
symbol=symbol,
interval_seconds=interval_seconds,
limit=1000 # Get 1000 candles for good history
)
# Convert pandas DataFrame to list of dictionaries for our environment
initial_candles = []
if not initial_data.empty:
for _, row in initial_data.iterrows():
candle = {
'timestamp': int(row['timestamp'].timestamp() * 1000),
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume'])
}
initial_candles.append(candle)
logger.info(f"Loaded {len(initial_candles)} historical candles")
else:
logger.warning("No historical data fetched")
# Initialize WebSocket for real-time data
binance_ws = BinanceWebSocket(symbol.replace('/', ''))
await binance_ws.connect()
        # Per-candle aggregation state is created and tracked inside
        # process_websocket_ticks, not here
logger.info(f"WebSocket for {symbol} initialized successfully")
return binance_ws, initial_candles
except Exception as e:
logger.error(f"Failed to initialize WebSocket data stream: {e}")
logger.error(traceback.format_exc())
return None, []
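# Usage sketch (illustrative): wire the stream into a running environment.
#   ws, history = await initialize_websocket_data_stream("ETH/USDT", "1m")
#   if ws is not None:
#       await process_websocket_ticks(ws, env, agent=None, demo=True)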
async def process_websocket_ticks(websocket, env, agent=None, demo=True, timeframe="1m"):
"""Process real-time ticks from WebSocket and aggregate them into candles
Args:
websocket: BinanceWebSocket instance
env: TradingEnvironment instance
agent: Agent instance (optional, for live trading)
demo: Whether to run in demo mode
timeframe: Timeframe for candle aggregation
"""
# Initialize variables for candle aggregation
current_candle = None
current_minute = None
trades_count = 0
step_counter = 0
try:
logger.info("Starting WebSocket tick processing...")
while websocket.running:
# Get the next tick from WebSocket
tick = await websocket.receive()
if tick is None:
# No data received, wait and try again
await asyncio.sleep(0.1)
continue
# Extract data from tick
timestamp = tick.get('timestamp')
price = tick.get('price')
volume = tick.get('volume')
if timestamp is None or price is None:
logger.warning(f"Invalid tick data received: {tick}")
continue
# Convert timestamp to datetime
tick_time = datetime.datetime.fromtimestamp(timestamp / 1000)
# For 1-minute candles, track the minute
if timeframe == "1m":
tick_minute = tick_time.replace(second=0, microsecond=0)
# If this is a new minute, close the current candle and start a new one
if current_minute is None or tick_minute > current_minute:
# If there was a previous candle, add it to the environment
if current_candle is not None:
# Add the candle to the environment
env.add_data(current_candle)
# Process trading decisions if agent is provided
if agent is not None:
state = env.get_state()
action = agent.select_action(state, training=False)
# Execute action in environment
next_state, reward, done, info = env.step(action)
# Log trading activity
action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
logger.info(f"Step {step_counter}: Action {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}")
step_counter += 1
# Start a new candle
current_minute = tick_minute
current_candle = {
'timestamp': int(current_minute.timestamp() * 1000),
'open': price,
'high': price,
'low': price,
'close': price,
'volume': volume
}
logger.debug(f"Started new candle at {current_minute}")
else:
# Update the current candle
current_candle['high'] = max(current_candle['high'], price)
current_candle['low'] = min(current_candle['low'], price)
current_candle['close'] = price
current_candle['volume'] += volume
            # For other timeframes, bucket ticks with an interval-based floor;
            # see the candle_bucket_start sketch after this function
            # ...
except asyncio.CancelledError:
logger.info("WebSocket processing canceled")
except Exception as e:
logger.error(f"Error in WebSocket tick processing: {e}")
logger.error(traceback.format_exc())
finally:
# Make sure to close the WebSocket
if websocket:
await websocket.close()
logger.info("WebSocket connection closed")
def ensure_pytorch_compatibility():
    """Ensure compatibility with PyTorch 2.6+ model loading by registering
    numpy scalar globals (module paths differ between numpy 1.x and 2.x)"""
    try:
        from torch.serialization import add_safe_globals
        # Register both numpy module paths; support for string registration
        # varies across PyTorch versions, hence the broad exception handling
        add_safe_globals(['numpy._core.multiarray.scalar'])
        add_safe_globals(['numpy.core.multiarray.scalar'])
        logger.info("Added numpy scalar to PyTorch safe globals")
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure PyTorch compatibility: {e}")
        logger.warning("This might cause issues with model loading in PyTorch 2.6+")
async def main():
# Ensure PyTorch compatibility
ensure_pytorch_compatibility()
parser = argparse.ArgumentParser(description='Trading Bot')
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
help='Operation mode: train, eval, or live')
parser.add_argument('--episodes', type=int, default=1000,
help='Number of episodes for training or evaluation')
parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
help='Run in demo mode (paper trading) if true')
parser.add_argument('--symbol', type=str, default='ETH/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m',
help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
parser.add_argument('--leverage', type=int, default=50,
help='Leverage for futures trading')
parser.add_argument('--model', type=str, default=None,
help='Path to model file for evaluation or live trading')
parser.add_argument('--use-websocket', action='store_true',
help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')
parser.add_argument('--dashboard', action='store_true',
help='Enable Dash dashboard visualization for real-time trading')
args = parser.parse_args()
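    # Example invocations (illustrative):
    #   python main.py --mode train --episodes 500
    #   python main.py --mode live --demo true --use-websocket --dashboard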
# Convert string boolean to actual boolean
demo_mode = args.demo.lower() == 'true'
# Get device (GPU or CPU)
device = get_device()
exchange = None
try:
# Initialize exchange
exchange = await initialize_exchange()
# Create environment
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
if args.mode == 'train':
# Fetch initial data for training
            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Create agent with consistent parameters
# Note: Using STATE_SIZE and action_size=4 for consistency
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Train the agent
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'eval' or args.mode == 'live':
# Fetch initial data for the specified symbol and timeframe
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Determine model path
model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
# Create agent with default parameters
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Try to load the model
try:
# Add numpy scalar to safe globals before loading
import numpy as np
from torch.serialization import add_safe_globals
# Add numpy scalar to safe globals
add_safe_globals(['numpy._core.multiarray.scalar'])
# Load the model
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
# Ask user if they want to continue with a new model
if args.mode == 'live':
confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
if confirmation.lower() != 'y':
logger.info("Live trading canceled by user")
return
logger.info("Continuing with a new model")
else:
logger.info("Continuing evaluation with a new model")
if args.mode == 'eval':
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'live':
# Start live trading
logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
if args.use_websocket:
logger.info("Using Binance WebSocket for real-time data")
await live_trading_with_websocket(
agent=agent,
env=env,
symbol=args.symbol,
timeframe=args.timeframe,
demo=demo_mode,
leverage=args.leverage,
use_dashboard=args.dashboard
)
else:
logger.info("Using CCXT for real-time data")
await live_trading(
agent=agent,
env=env,
exchange=exchange,
symbol=args.symbol,
timeframe=args.timeframe,
demo=demo_mode,
leverage=args.leverage
)
except Exception as e:
logger.error(f"Error in main function: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
# Clean up exchange connection
if exchange:
try:
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}")
def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
"""Create a candlestick chart with trade signals for TensorBoard visualization"""
if len(data) < 10:
return None
try:
# Create figure
fig = plt.figure(figsize=(12, 8))
# Prepare data for plotting
df = pd.DataFrame(data[-window_size:])
df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
df.set_index('date', inplace=True)
# Create subplot grid
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
price_ax = plt.subplot(gs[0])
volume_ax = plt.subplot(gs[1], sharex=price_ax)
        # Plot candlesticks; this call assumes `mpf` is the mplfinance package.
        # If the alias actually points at matplotlib.pyplot (as in the imports
        # above), the call raises and we fall back to a simple line plot
        try:
            mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
except Exception as e:
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
# Fallback to simple plot
price_ax.plot(df.index, df['close'], label='Price')
volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)
# Add trade signals
for signal in trade_signals:
try:
timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
price = signal['price']
if signal['type'] == 'buy':
price_ax.plot(timestamp, price, '^', color='green', markersize=10)
elif signal['type'] == 'sell':
price_ax.plot(timestamp, price, 'v', color='red', markersize=10)
elif signal['type'] == 'close_long':
price_ax.plot(timestamp, price, 'x', color='gold', markersize=10)
elif signal['type'] == 'close_short':
price_ax.plot(timestamp, price, 'x', color='black', markersize=10)
elif 'stop_loss' in signal['type']:
price_ax.plot(timestamp, price, 'X', color='purple', markersize=10)
elif 'take_profit' in signal['type']:
price_ax.plot(timestamp, price, '*', color='cyan', markersize=10)
except Exception as e:
logger.warning(f"Error plotting signal: {e}")
continue
# Add balance and PnL annotation
if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
balance = trade_signals[-1]['balance']
pnl = trade_signals[-1]['pnl']
price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
xy=(0.02, 0.95), xycoords='axes fraction',
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
# Set title and format
price_ax.set_title(title)
fig.tight_layout()
# Convert to image
buf = io.BytesIO()
fig.savefig(buf, format='png')
buf.seek(0)
plt.close(fig)
img = Image.open(buf)
return img
except Exception as e:
logger.error(f"Error creating chart: {str(e)}")
return None
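# Usage sketch (illustrative): render a standalone chart image for debugging.
#   img = create_candlestick_figure(env.data, env.trade_signals, title="Debug chart")
#   if img is not None:
#       img.save("debug_chart.png")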
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False):
"""Run the trading bot in live mode using Binance WebSocket for real-time data
Args:
agent: The trading agent to use for decision making
env: The trading environment
symbol: The trading pair symbol (e.g., "ETH/USDT")
timeframe: The candlestick timeframe (e.g., "1m")
demo: Whether to run in demo mode (paper trading)
leverage: The leverage to use for trading
use_dashboard: Whether to display the real-time dashboard
Returns:
None
"""
logger.info(f"Starting live trading with WebSocket for {symbol} on {timeframe} timeframe")
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
# If not demo mode, confirm with user before starting live trading
if not demo:
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
if confirmation != "CONFIRM":
logger.info("Live trading canceled by user")
return
# Initialize TensorBoard for monitoring
if not hasattr(agent, 'writer') or agent.writer is None:
from torch.utils.tensorboard import SummaryWriter
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')
# Initialize Dash dashboard if enabled
dashboard = None
if use_dashboard:
try:
dashboard = TradingDashboard(symbol)
dashboard_started = dashboard.start() # Start the dashboard in a separate thread
if dashboard_started:
logger.info(f"Trading dashboard enabled at http://localhost:8060")
else:
logger.warning("Failed to start trading dashboard, continuing without visualization")
dashboard = None
except Exception as e:
logger.error(f"Error initializing dashboard: {e}")
logger.error(traceback.format_exc())
dashboard = None
# Track performance metrics
trades_count = 0
winning_trades = 0
total_profit = 0
max_drawdown = 0
peak_balance = env.balance
step_counter = 0
# Create directory for trade logs
os.makedirs('trade_logs', exist_ok=True)
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
with open(trade_log_path, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
try:
# Initialize WebSocket connection and get historical data
websocket, initial_candles = await initialize_websocket_data_stream(symbol, timeframe)
if websocket is None or not initial_candles:
logger.error("Failed to initialize WebSocket data stream")
return
# Load initial historical data into the environment
logger.info(f"Loading {len(initial_candles)} initial candles into environment")
for candle in initial_candles:
env.add_data(candle)
# Reset environment with historical data
env.reset()
# Update dashboard with initial data if enabled
if dashboard:
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
# Initialize futures trading if not in demo mode
exchange = None
if not demo:
# Import ccxt for exchange initialization
import ccxt.async_support as ccxt_async
# Initialize exchange for order execution
exchange = await initialize_exchange()
if exchange:
try:
await env.initialize_futures(exchange)
logger.info(f"Futures trading initialized with {leverage}x leverage")
except Exception as e:
logger.error(f"Failed to initialize futures trading: {str(e)}")
logger.info("Falling back to demo mode for safety")
demo = True
# Start WebSocket processing in the background
websocket_task = asyncio.create_task(
process_websocket_ticks(websocket, env, agent, demo, timeframe)
)
# Main tracking loop
prev_position = 'flat'
while True:
try:
# Check if position has changed
if env.position != prev_position:
trades_count += 1
if hasattr(env, 'last_trade_profit') and env.last_trade_profit > 0:
winning_trades += 1
if hasattr(env, 'last_trade_profit'):
total_profit += env.last_trade_profit
# Log trade details
current_time = datetime.datetime.now().isoformat()
action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE"
with open(trade_log_path, 'a') as f:
f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n")
logger.info(f"Trade executed: {action_name} at ${env.current_price:.2f}, PnL: ${getattr(env, 'last_trade_profit', 0):.2f}")
# Update performance metrics
if env.balance > peak_balance:
peak_balance = env.balance
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
if current_drawdown > max_drawdown:
max_drawdown = current_drawdown
# Update TensorBoard metrics
step_counter += 1
if step_counter % 10 == 0: # Update every 10 steps
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
# Update chart visualization
if step_counter % 30 == 0 or env.position != prev_position:
agent.add_chart_to_tensorboard(env, step_counter)
# Log performance summary
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
performance_text = f"""
**Live Trading Performance**
Balance: ${env.balance:.2f}
Total PnL: ${env.total_pnl:.2f}
Trades: {trades_count}
Win Rate: {win_rate:.1f}%
Max Drawdown: {max_drawdown*100:.1f}%
"""
agent.writer.add_text('Performance', performance_text, step_counter)
# Update the dashboard with latest data if enabled
if dashboard:
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
prev_position = env.position
# Sleep for a short time to prevent CPU hogging
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Error in live trading monitor loop: {str(e)}")
logger.error(traceback.format_exc())
await asyncio.sleep(10) # Wait longer after an error
except KeyboardInterrupt:
logger.info("Live trading stopped by user")
# Cancel the WebSocket task
if 'websocket_task' in locals() and not websocket_task.done():
websocket_task.cancel()
try:
await websocket_task
except asyncio.CancelledError:
pass
# Close the exchange connection if it exists
if exchange:
await exchange.close()
# Final performance report
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
logger.info(f"Trading session summary:")
logger.info(f"Total trades: {trades_count}")
logger.info(f"Win rate: {win_rate:.1f}%")
logger.info(f"Final balance: ${env.balance:.2f}")
logger.info(f"Total profit: ${total_profit:.2f}")
except Exception as e:
logger.error(f"Critical error in live trading: {str(e)}")
logger.error(traceback.format_exc())
finally:
# Make sure to close WebSocket
if 'websocket' in locals() and websocket:
await websocket.close()
# Close the exchange connection if it exists
if 'exchange' in locals() and exchange:
await exchange.close()
class TradingDashboard:
"""Dashboard for visualizing trading activity with Dash"""
def __init__(self, symbol="ETH/USDT"):
self.symbol = symbol
self.env = None
self.candles = []
self.trade_signals = []
# Create Dash app
self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
# Create basic layout
self.app.layout = html.Div([
            # (Data is read directly from this instance by the interval
            # callbacks; no hidden-div stores are needed)
# Header
html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}),
# Main content
html.Div([
# Chart
html.Div([
dcc.Graph(id='candlestick-chart', style={'height': '70vh'}),
dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0)
], style={'width': '70%', 'display': 'inline-block'}),
# Trading info
html.Div([
html.Div([
html.H3("Account Info"),
html.Div(id='account-info')
]),
html.Div([
html.H3("Recent Trades"),
html.Div(id='recent-trades')
])
], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'})
])
])
# Setup callbacks
self._setup_callbacks()
# Thread for running the server
self.thread = None
self.is_running = False
def _setup_callbacks(self):
        @self.app.callback(
            Output('candlestick-chart', 'figure'),
            [Input('interval-component', 'n_intervals')]
        )
        def update_chart(n):
            # Read the latest data directly from the dashboard instance;
            # mutating app.layout after startup does not propagate to
            # connected clients, so hidden-div stores are avoided
            candles = self.candles or []
            signals = self.trade_signals or []
# Create figure with subplots
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
vertical_spacing=0.1, row_heights=[0.7, 0.3])
if candles:
# Convert to dataframe
df = pd.DataFrame(candles[-100:]) # Show last 100 candles
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
# Add candlestick trace
fig.add_trace(
go.Candlestick(
x=df['timestamp'],
open=df['open'],
high=df['high'],
low=df['low'],
close=df['close'],
name='Price'
),
row=1, col=1
)
# Add volume trace
fig.add_trace(
go.Bar(
x=df['timestamp'],
y=df['volume'],
name='Volume'
),
row=2, col=1
)
# Add trade signals
for signal in signals:
if signal['timestamp'] >= df['timestamp'].iloc[0].timestamp() * 1000:
signal_time = pd.to_datetime(signal['timestamp'], unit='ms')
marker_color = 'green' if signal['type'] == 'buy' else 'red' if signal['type'] == 'sell' else 'orange'
marker_symbol = 'triangle-up' if signal['type'] == 'buy' else 'triangle-down' if signal['type'] == 'sell' else 'circle'
# Add marker for signal
fig.add_trace(
go.Scatter(
x=[signal_time],
y=[signal['price']],
mode='markers',
marker=dict(
color=marker_color,
size=12,
symbol=marker_symbol
),
name=signal['type'].capitalize(),
showlegend=False
),
row=1, col=1
)
# Update layout
fig.update_layout(
title=f'{self.symbol} Trading Chart',
xaxis_rangeslider_visible=False,
template='plotly_dark'
)
return fig
@self.app.callback(
[Output('account-info', 'children'),
Output('recent-trades', 'children')],
[Input('interval-component', 'n_intervals')]
)
def update_account_info(n):
if not self.env:
return "No data available", "No trades available"
# Account info
account_info = html.Div([
html.P(f"Balance: ${self.env.balance:.2f}"),
html.P(f"PnL: ${self.env.total_pnl:.2f}",
style={'color': 'green' if self.env.total_pnl > 0 else 'red' if self.env.total_pnl < 0 else 'white'}),
html.P(f"Position: {self.env.position.upper()}")
])
# Recent trades
if hasattr(self.env, 'trades') and self.env.trades:
# Get last 5 trades
recent_trades = []
for trade in reversed(self.env.trades[-5:]):
trade_card = html.Div([
html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"),
html.P(f"PnL: ${trade['pnl']:.2f}",
style={'color': 'green' if trade['pnl'] > 0 else 'red' if trade['pnl'] < 0 else 'white'})
], style={'border': '1px solid #ddd', 'padding': '10px', 'margin-bottom': '5px'})
recent_trades.append(trade_card)
else:
recent_trades = [html.P("No trades yet")]
return account_info, recent_trades
    def update_data(self, env=None, candles=None, trade_signals=None):
        """Update dashboard data; the interval callbacks read these attributes
        directly on each tick, so no layout mutation is required"""
        if env:
            self.env = env
        if candles:
            self.candles = candles
        if trade_signals:
            self.trade_signals = trade_signals
def start(self, host='localhost', port=8060):
"""Start the dashboard server in a separate thread"""
if not self.is_running:
            # Probe for a free port, using a fresh socket per attempt and
            # releasing it before the server binds
            port_available = False
            for attempt_port in range(port, port + 10):
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    sock.bind((host, attempt_port))
                    port_available = True
                    port = attempt_port
                    break
                except socket.error:
                    logger.warning(f"Port {attempt_port} is already in use")
                finally:
                    sock.close()
            if not port_available:
                logger.error("Could not find an available port for dashboard")
                return False
# Create and start the thread
self.thread = Thread(target=self._run_server, args=(host, port))
self.thread.daemon = True # This ensures the thread will exit when the main program does
self.thread.start()
self.is_running = True
logger.info(f"Trading dashboard started at http://{host}:{port}")
# Verify the thread actually started
if not self.thread.is_alive():
logger.error("Dashboard thread failed to start")
return False
# Wait a short time to let the server initialize
time.sleep(1.0)
return True
return False
def _run_server(self, host, port):
"""Run the Dash server"""
try:
logger.info(f"Starting Dash server on {host}:{port}")
self.app.run_server(debug=False, host=host, port=port, use_reloader=False, threaded=True)
except Exception as e:
logger.error(f"Error running dashboard server: {e}")
self.is_running = False
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")