# gogo2/crypto/gogo2/main.py

import os
import time
import json
import numpy as np
import pandas as pd
import datetime
import random
import logging
import asyncio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from dotenv import load_dotenv
import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp # Automatic mixed precision (AMP) utilities
from sklearn.preprocessing import MinMaxScaler
import copy
import argparse
import traceback
import io
import matplotlib.dates as mdates
from matplotlib.figure import Figure
from PIL import Image
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)
logger = logging.getLogger("trading_bot")
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
# Constants
INITIAL_BALANCE = 100 # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5%
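# Note: at 100x leverage a 0.5% adverse price move is roughly a 50% loss on the position's margin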
MEMORY_SIZE = 100000
BATCH_SIZE = 64
GAMMA = 0.99 # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
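# Exploration rate decays as EPSILON_END + (EPSILON_START - EPSILON_END) * exp(-steps_done / EPSILON_DECAY)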
STATE_SIZE = 40 # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10 # Update target network every 10 episodes
# Experience replay tuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
# Helper used by TradingEnvironment.identify_optimal_trades to label ideal entry/exit points
def find_local_extrema(prices, window=5):
"""Find local minima (bottoms) and maxima (tops) in price data"""
bottoms = []
tops = []
if len(prices) < window * 2 + 1:
return bottoms, tops
for i in range(window, len(prices) - window):
# Check if this is a local minimum (bottom)
if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
all(prices[i] <= prices[i+j] for j in range(1, window+1)):
bottoms.append(i)
# Check if this is a local maximum (top)
if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
all(prices[i] >= prices[i+j] for j in range(1, window+1)):
tops.append(i)
return bottoms, tops
class ReplayMemory:
def __init__(self, capacity):
self.memory = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.memory.append(Experience(state, action, reward, next_state, done))
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
class DQN(nn.Module):
def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
super(DQN, self).__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.lstm_layers = lstm_layers
# Initial feature extraction
self.fc1 = nn.Linear(state_size, hidden_size)
# Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
self.ln1 = nn.LayerNorm(hidden_size)
self.dropout1 = nn.Dropout(0.2)
# LSTM layer for sequential data
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)
# Attention mechanism
self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
# Output layers with increased capacity
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm
self.dropout2 = nn.Dropout(0.2)
self.fc3 = nn.Linear(hidden_size, hidden_size // 2)
# Dueling DQN architecture
self.value_stream = nn.Linear(hidden_size // 2, 1)
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
# Transformer encoder for more complex pattern recognition
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
def forward(self, x):
batch_size = x.size(0) if x.dim() > 1 else 1
# Ensure input has correct shape
if x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension
# Check if state size matches expected input size
if x.size(1) != self.state_size:
# Handle mismatched input by either truncating or padding
if x.size(1) > self.state_size:
x = x[:, :self.state_size] # Truncate
else:
# Pad with zeros
padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
x = torch.cat([x, padding], dim=1)
# Initial feature extraction
x = self.fc1(x)
x = F.relu(self.ln1(x)) # LayerNorm works with any batch size
x = self.dropout1(x)
# Reshape for LSTM
x_lstm = x.unsqueeze(1) if x.dim() == 2 else x
# Process through LSTM
lstm_out, _ = self.lstm(x_lstm)
lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]
# Process through transformer for more complex patterns
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
# Combine LSTM and transformer outputs
x = lstm_out + transformer_out
# Final layers
x = self.fc2(x)
x = F.relu(self.ln2(x)) # LayerNorm works with any batch size
x = self.dropout2(x)
x = F.relu(self.fc3(x))
# Dueling architecture
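# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); subtracting the mean advantage keeps V and A identifiable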
value = self.value_stream(x)
advantages = self.advantage_stream(x)
qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))
return qvals
class PricePredictionModel(nn.Module):
def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
super(PricePredictionModel, self).__init__()
self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
self.fc = nn.Linear(hidden_size, output_size)
self.scaler = MinMaxScaler(feature_range=(0, 1))
self.is_fitted = False
def forward(self, x):
# x shape: [batch_size, seq_len, 1]
lstm_out, _ = self.lstm(x)
# Use the last time step output
predictions = self.fc(lstm_out[:, -1, :])
return predictions
def preprocess(self, data):
# Reshape data for scaler
data_reshaped = np.array(data).reshape(-1, 1)
# Fit scaler if not already fitted
if not self.is_fitted:
self.scaler.fit(data_reshaped)
self.is_fitted = True
# Transform data
scaled_data = self.scaler.transform(data_reshaped)
return scaled_data
def postprocess(self, scaled_predictions):
# Inverse transform to get actual price values
return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()
def predict_next_candles(self, price_history, num_candles=5):
if len(price_history) < 30: # Need enough history
return np.zeros(num_candles)
# Preprocess data
scaled_data = self.preprocess(price_history)
# Create sequence
sequence = scaled_data[-30:].reshape(1, 30, 1)
sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)
# Get predictions
with torch.no_grad():
scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
# Postprocess predictions
predictions = self.postprocess(scaled_predictions)
return predictions
def train_on_new_data(self, price_history, optimizer, epochs=10):
if len(price_history) < 35: # Need enough history for training
return 0.0
# Preprocess data
scaled_data = self.preprocess(price_history)
# Create sequences and targets
sequences = []
targets = []
for i in range(len(scaled_data) - 35):
# Sequence: 30 time steps
seq = scaled_data[i:i+30]
# Target: next 5 time steps
target = scaled_data[i+30:i+35].flatten()
sequences.append(seq)
targets.append(target)
if not sequences: # If no sequences were created
return 0.0
# Convert to tensors
sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device)
targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device)
# Training loop
total_loss = 0
for _ in range(epochs):
# Forward pass
predictions = self(sequences_tensor)
# Calculate loss
loss = F.mse_loss(predictions, targets_tensor)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / epochs
class TradingEnvironment:
def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
"""Initialize the trading environment"""
self.initial_balance = initial_balance
self.balance = initial_balance
self.window_size = window_size
self.demo = demo
self.data = []
self.position = 'flat' # 'flat', 'long', or 'short'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.trades = []
self.win_count = 0
self.loss_count = 0
self.total_pnl = 0.0
self.episode_pnl = 0.0
self.peak_balance = initial_balance
self.max_drawdown = 0.0
self.current_step = 0
self.current_price = 0
# For tracking signals for visualization
self.trade_signals = []
# Initialize features
self.features = {
'price': [],
'volume': [],
'rsi': [],
'macd': [],
'macd_signal': [],
'macd_hist': [],
'bollinger_upper': [],
'bollinger_mid': [],
'bollinger_lower': [],
'stoch_k': [],
'stoch_d': [],
'ema_9': [],
'ema_21': [],
'atr': []
}
# Initialize price predictor
self.price_predictor = None
self.predicted_prices = np.array([])
# Initialize optimal trade tracking
self.optimal_bottoms = []
self.optimal_tops = []
self.optimal_signals = np.array([])
# Futures trading parameters
self.leverage = MAX_LEVERAGE
self.futures_symbol = "ETH_USDT" # Example futures symbol
self.position_mode = "hedge" # For simultaneous long/short positions
self.margin_mode = "cross" # Cross margin mode
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
self.position = 'flat'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.trades = []
self.win_count = 0
self.loss_count = 0
self.episode_pnl = 0.0
self.peak_balance = self.initial_balance
self.max_drawdown = 0.0
self.current_step = 0
# Keep data but reset current position
if len(self.data) > self.window_size:
self.current_step = self.window_size
self.current_price = self.data[self.current_step]['close']
# Reset trade signals
self.trade_signals = []
return self.get_state()
def add_data(self, candle):
"""Add a new candle to the data"""
self.data.append(candle)
self._update_features()
self.current_price = candle['close']
def _initialize_features(self):
"""Initialize technical indicators and features"""
if len(self.data) < 30:
return
# Convert data to pandas DataFrame for easier calculation
df = pd.DataFrame(self.data)
# Basic price and volume
self.features['price'] = df['close'].values
self.features['volume'] = df['volume'].values
# Calculate RSI (14 periods)
delta = df['close'].diff()
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
rs = gain / loss
self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values
# Calculate MACD
ema12 = df['close'].ewm(span=12, adjust=False).mean()
ema26 = df['close'].ewm(span=26, adjust=False).mean()
macd = ema12 - ema26
signal = macd.ewm(span=9, adjust=False).mean()
self.features['macd'] = macd.values
self.features['macd_signal'] = signal.values
self.features['macd_hist'] = (macd - signal).values
# Calculate Bollinger Bands
sma20 = df['close'].rolling(window=20).mean()
std20 = df['close'].rolling(window=20).std()
self.features['bollinger_upper'] = (sma20 + 2 * std20).values
self.features['bollinger_mid'] = sma20.values
self.features['bollinger_lower'] = (sma20 - 2 * std20).values
# Calculate Stochastic Oscillator
low_14 = df['low'].rolling(window=14).min()
high_14 = df['high'].rolling(window=14).max()
k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
self.features['stoch_k'] = k.values
self.features['stoch_d'] = k.rolling(window=3).mean().values
# Calculate EMAs
self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values
# Calculate ATR
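# True range = max(high - low, |high - prev_close|, |low - prev_close|); ATR is its 14-period rolling mean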
high_low = df['high'] - df['low']
high_close = (df['high'] - df['close'].shift()).abs()
low_close = (df['low'] - df['close'].shift()).abs()
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values
def _update_features(self):
"""Update technical indicators with new data"""
self._initialize_features() # Recalculate all features
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""Fetch initial historical data for the environment"""
try:
logger.info(f"Fetching initial data for {symbol}")
# Use the refactored fetch method
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
# Update environment with fetched data
if data:
self.data = data
self._initialize_features()
logger.info(f"Initialized environment with {len(data)} candles")
else:
logger.warning("No initial data received")
return len(data) > 0
except Exception as e:
logger.error(f"Error fetching initial data: {e}")
return False
def step(self, action):
"""Take an action in the environment and return the next state, reward, and done flag"""
# Store current price before taking action
self.current_price = self.data[self.current_step]['close']
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
reward = self.calculate_reward(action)
# Record trade signal for visualization
if action > 0: # If not HOLD
signal_type = None
if action == 1: # BUY/LONG
signal_type = 'buy'
elif action == 2: # SELL/SHORT
signal_type = 'sell'
elif action == 3: # CLOSE
if self.position == 'long':
signal_type = 'close_long'
elif self.position == 'short':
signal_type = 'close_short'
if signal_type:
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.current_price,
'type': signal_type,
'balance': self.balance,
'pnl': self.total_pnl
})
# Check for stop loss / take profit hits
self.check_sl_tp()
# Move to next step
self.current_step += 1
done = self.current_step >= len(self.data) - 1
# Get new state
next_state = self.get_state()
return next_state, reward, done
def check_sl_tp(self):
"""Check if stop loss or take profit has been hit"""
if self.position == 'flat':
return
if self.position == 'long':
# Check stop loss
if self.current_price <= self.stop_loss:
# Stop loss hit
pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.stop_loss,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'stop_loss'
})
# Update win/loss count
self.loss_count += 1
logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.stop_loss,
'type': 'stop_loss_long',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif self.current_price >= self.take_profit:
# Take profit hit
pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.take_profit,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'take_profit'
})
# Update win/loss count
self.win_count += 1
logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.take_profit,
'type': 'take_profit_long',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
elif self.position == 'short':
# Check stop loss
if self.current_price >= self.stop_loss:
# Stop loss hit
pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.stop_loss,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'stop_loss'
})
# Update win/loss count
self.loss_count += 1
logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.stop_loss,
'type': 'stop_loss_short',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif self.current_price <= self.take_profit:
# Take profit hit
pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.take_profit,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': 'take_profit'
})
# Update win/loss count
self.win_count += 1
logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Record signal for visualization
self.trade_signals.append({
'timestamp': self.data[self.current_step]['timestamp'],
'price': self.take_profit,
'type': 'take_profit_short',
'balance': self.balance,
'pnl': self.total_pnl
})
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
def get_state(self):
"""Create state representation for the agent with enhanced features"""
if len(self.data) < 30 or len(self.features['price']) == 0:
# Return zeros if not enough data
return np.zeros(STATE_SIZE)
# Create a normalized state vector with recent price action and indicators
state_components = []
# Price features (normalize recent prices by the latest price)
latest_price = self.features['price'][-1]
price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
state_components.append(price_features)
# Volume features (normalize by max volume)
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
vol_features = np.array(self.features['volume'][-5:]) / max_vol
state_components.append(vol_features)
# Technical indicators
rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1
state_components.append(rsi)
# MACD (normalize)
macd_vals = np.array(self.features['macd'][-3:])
macd_signal = np.array(self.features['macd_signal'][-3:])
macd_hist = np.array(self.features['macd_hist'][-3:])
macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
macd_norm = macd_vals / macd_scale
macd_signal_norm = macd_signal / macd_scale
macd_hist_norm = macd_hist / macd_scale
state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm])
# Bollinger position (where is price relative to bands)
bb_upper = np.array(self.features['bollinger_upper'][-3:])
bb_lower = np.array(self.features['bollinger_lower'][-3:])
bb_mid = np.array(self.features['bollinger_mid'][-3:])
price = np.array(self.features['price'][-3:])
# Calculate position of price within Bollinger Bands (0 to 1)
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
state_components.append(np.array(bb_pos))
# Stochastic oscillator
state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0)
state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0)
# Add predicted prices (if available)
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
# Normalize predictions relative to current price
pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
state_components.append(pred_norm)
else:
# Add zeros if no predictions
state_components.append(np.zeros(3))
# Add extrema signals (if available)
if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0:
# Get recent signals
idx = len(self.optimal_signals) - 5
if idx < 0:
idx = 0
recent_signals = self.optimal_signals[idx:idx+5]
# Pad if needed
if len(recent_signals) < 5:
recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
state_components.append(recent_signals)
else:
# Add zeros if no signals
state_components.append(np.zeros(5))
# Position info
position_info = np.zeros(5)
if self.position == 'long':
position_info[0] = 1.0 # Position is long
position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL %
position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss %
position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit %
position_info[4] = self.position_size / self.balance # Position size relative to balance
elif self.position == 'short':
position_info[0] = -1.0 # Position is short
position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL %
position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss %
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
position_info[4] = self.position_size / self.balance # Position size relative to balance
state_components.append(position_info)
# Additional engineered features: momentum, volatility, market regime, and trade history
# 1. Price momentum features (rate of change over different periods)
if len(self.features['price']) >= 20:
roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0
roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0
roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0
momentum_features = np.array([roc_5, roc_10, roc_20])
state_components.append(momentum_features)
else:
state_components.append(np.zeros(3))
# 2. Volatility features
if len(self.features['price']) >= 20:
# Calculate price returns
returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1]
# Calculate volatility (standard deviation of returns)
volatility = np.std(returns)
# Calculate normalized high-low range
high_low_range = np.mean([
(self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close']
for i in range(max(0, len(self.data)-5), len(self.data))
]) if len(self.data) > 0 else 0
# ATR normalized by price
atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0
volatility_features = np.array([volatility, high_low_range, atr_norm])
state_components.append(volatility_features)
else:
state_components.append(np.zeros(3))
# 3. Market regime features
if len(self.features['price']) >= 50:
# Trend strength (ADX-like measure)
ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
trend_strength = abs(ema9 - ema21) / ema21
# Detect if in range or trending
is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0
# Detect if near support/resistance
near_support = 1.0 if self.is_near_support() else 0.0
near_resistance = 1.0 if self.is_near_resistance() else 0.0
market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance])
state_components.append(market_regime)
else:
state_components.append(np.zeros(5))
# 4. Trade history features
if len(self.trades) > 0:
# Recent win/loss ratio
recent_trades = self.trades[-min(10, len(self.trades)):]
win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades)
# Average profit/loss
avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0
avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0
# Normalize by balance
avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0
avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0
# Last trade result
last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0
trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl])
state_components.append(trade_history)
else:
state_components.append(np.zeros(4))
# Combine all features
state = np.concatenate([comp.flatten() for comp in state_components])
# Replace any NaN values
state = np.nan_to_num(state, nan=0.0)
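# With full history this vector has roughly 64 entries, more than STATE_SIZE (40); DQN.forward pads or
# truncates to its configured input size, and expand_model_with_new_features can grow the network instead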
# Return the state (the caller will handle sizing)
return state
def get_expanded_state_size(self):
"""Calculate the size of the expanded state representation"""
# Create a dummy state to get its size
state = self.get_state()
return len(state)
async def expand_model_with_new_features(agent, env):
"""Expand the model to handle new features without retraining from scratch"""
# Get the new state size
new_state_size = env.get_expanded_state_size()
# Only expand if the new state size is larger
if new_state_size > agent.state_size:
logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")
# Expand the model
success = agent.expand_model(
new_state_size=new_state_size,
new_hidden_size=512, # Increase hidden size for more capacity
new_lstm_layers=3, # More layers for deeper patterns
new_attention_heads=8 # More attention heads for complex relationships
)
if success:
logger.info(f"Model successfully expanded to handle {new_state_size} features")
return True
else:
logger.error("Failed to expand model")
return False
else:
logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
return True
def calculate_reward(self, action):
"""Calculate reward for the given action with improved penalties for losing trades"""
reward = 0
# Base reward for actions
if action == 0: # HOLD
reward = -0.01 # Small penalty for doing nothing
elif action == 1: # BUY/LONG
if self.position == 'flat':
# Opening a long position
self.position = 'long'
self.entry_price = self.current_price
self.entry_index = self.current_step # Remember the entry bar so trade duration can be computed
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal buy point (bottom)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
reward += 2.0 # Bonus for buying at a bottom
else:
# Check if we're buying in a downtrend (bad)
if self.is_downtrend():
reward -= 0.5 # Penalty for buying in downtrend
else:
reward += 0.1 # Small reward for opening a position
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif self.position == 'short':
# Close short and open long
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Record trade
trade_duration = len(self.features['price']) - self.entry_index
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': trade_duration,
'market_direction': self.get_market_direction()
})
# Reward based on PnL with stronger penalties for losses
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
# Stronger penalty for losses, scaled by the size of the loss
loss_penalty = 1.0 + abs(pnl_dollar) / 5
reward -= loss_penalty
self.loss_count += 1
# Extra penalty for closing a losing trade too quickly
if trade_duration < 5:
reward -= 0.5 # Penalty for very short losing trades
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Now open long
self.position = 'long'
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal buy point
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
reward += 2.0 # Bonus for buying at a bottom
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif action == 2: # SELL/SHORT
if self.position == 'flat':
# Opening a short position
self.position = 'short'
self.entry_price = self.current_price
self.entry_index = self.current_step # Remember the entry bar so trade duration can be computed
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal sell point (top)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
reward += 2.0 # Bonus for selling at a top
else:
reward += 0.1 # Small reward for opening a position
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif self.position == 'long':
# Close long and open short
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Now open short
self.position = 'short'
self.entry_price = self.current_price
self.entry_index = self.current_step # Remember the entry bar so trade duration can be computed
self.position_size = self.calculate_position_size()
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
# Check if this is an optimal sell point
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
reward += 2.0 # Bonus for selling at a top
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
elif action == 3: # CLOSE
if self.position == 'long':
# Close long position
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
elif self.position == 'short':
# Close short position
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Apply fees
pnl_dollar -= self.calculate_fees(self.position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar
})
# Reward based on PnL
if pnl_dollar > 0:
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
self.win_count += 1
else:
reward -= 1.0 # Negative reward for loss
self.loss_count += 1
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Add prediction accuracy component to reward
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
# Compare the first prediction with actual price
if len(self.data) > 1:
actual_price = self.data[-1]['close']
predicted_price = self.predicted_prices[0]
prediction_error = abs(predicted_price - actual_price) / actual_price
# Reward accurate predictions, penalize bad ones
if prediction_error < 0.005: # Less than 0.5% error
reward += 0.5
elif prediction_error > 0.02: # More than 2% error
reward -= 0.5
return reward
def is_downtrend(self):
"""Check if the market is in a downtrend"""
if len(self.features['price']) < 20:
return False
# Use EMA to determine trend
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
# Downtrend if short EMA is below long EMA
return short_ema < long_ema
def is_uptrend(self):
"""Check if the market is in an uptrend"""
if len(self.features['price']) < 20:
return False
# Use EMA to determine trend
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
# Uptrend if short EMA is above long EMA
return short_ema > long_ema
def get_market_direction(self):
"""Get the current market direction"""
if self.is_uptrend():
return "uptrend"
elif self.is_downtrend():
return "downtrend"
else:
return "sideways"
def analyze_trades(self):
"""Analyze completed trades to identify patterns"""
if not self.trades:
return {}
analysis = {
'total_trades': len(self.trades),
'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0),
'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0),
'avg_win': 0,
'avg_loss': 0,
'avg_duration': 0,
'uptrend_win_rate': 0,
'downtrend_win_rate': 0,
'sideways_win_rate': 0
}
# Calculate averages
wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0]
losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0]
durations = [t.get('duration', 0) for t in self.trades]
analysis['avg_win'] = sum(wins) / len(wins) if wins else 0
analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0
analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0
# Calculate win rates by market direction
for direction in ['uptrend', 'downtrend', 'sideways']:
direction_trades = [t for t in self.trades if t.get('market_direction') == direction]
if direction_trades:
wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0)
analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100
return analysis
def initialize_price_predictor(self, device="cpu"):
"""Initialize the price prediction model"""
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
self.price_predictor.to(device)
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
self.predicted_prices = np.array([])
def train_price_predictor(self):
"""Train the price prediction model on recent data"""
if len(self.features['price']) < 35:
return 0.0
# Get price history
price_history = self.features['price']
# Train the model
loss = self.price_predictor.train_on_new_data(
price_history,
self.price_predictor_optimizer,
epochs=5
)
return loss
def update_price_predictions(self):
"""Update price predictions"""
if len(self.features['price']) < 30:
self.predicted_prices = np.array([])
return
# Get price history
price_history = self.features['price']
# Get predictions
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
def identify_optimal_trades(self):
"""Identify optimal entry and exit points based on local extrema"""
if len(self.features['price']) < 20:
return
# Find local bottoms and tops
bottoms, tops = find_local_extrema(self.features['price'], window=5)
# Store optimal trade points
self.optimal_bottoms = bottoms # Buy points
self.optimal_tops = tops # Sell points
# Create optimal trade signals
self.optimal_signals = np.zeros(len(self.features['price']))
for i in bottoms:
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
self.optimal_signals[i] = 1 # Buy signal
for i in tops:
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
self.optimal_signals[i] = -1 # Sell signal
logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
def calculate_position_size(self):
"""Calculate position size based on current balance and risk parameters"""
# Use a fixed percentage of balance for each trade
risk_percent = 5.0 # Risk 5% of balance per trade
# Calculate position size with leverage
position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE
# Apply a safety factor to avoid liquidation
safety_factor = 0.8
position_size *= safety_factor
# Ensure minimum position size
min_position = 10.0 # Minimum position size in USD
position_size = max(position_size, min(min_position, self.balance * 0.5))
# Ensure position size doesn't exceed balance * leverage
max_position = self.balance * MAX_LEVERAGE
position_size = min(position_size, max_position)
return position_size
def calculate_fees(self, position_size):
"""Calculate trading fees for a given position size"""
# Typical fee rate for crypto exchanges (0.1%)
fee_rate = 0.001
# Calculate fee
fee = position_size * fee_rate
return fee
def is_uncertain_market(self):
"""Check if the market is in an uncertain/sideways state"""
if len(self.features['price']) < 20:
return True
# Check if price is within a narrow range
recent_prices = self.features['price'][-20:]
price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices)
# Check if EMAs are close to each other
if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
short_ema = self.features['ema_9'][-1]
long_ema = self.features['ema_21'][-1]
ema_diff = abs(short_ema - long_ema) / long_ema
# Return True if price range is small and EMAs are close
return price_range < 0.02 and ema_diff < 0.005
return price_range < 0.015 # Very narrow range
def is_near_support(self):
"""Check if current price is near a support level"""
if not hasattr(self, 'features') or len(self.features['price']) < 30:
return False
# Find recent lows
prices = self.features['price'][-30:]
lows = []
for i in range(1, len(prices)-1):
if prices[i] < prices[i-1] and prices[i] < prices[i+1]:
lows.append(prices[i])
if not lows:
return False
# Check if current price is near any of these lows
current_price = self.current_price
for low in lows:
if abs(current_price - low) / low < 0.01: # Within 1% of a recent low
return True
return False
def is_near_resistance(self):
"""Check if current price is near a resistance level"""
if not hasattr(self, 'features') or len(self.features['price']) < 30:
return False
# Find recent highs
prices = self.features['price'][-30:]
highs = []
for i in range(1, len(prices)-1):
if prices[i] > prices[i-1] and prices[i] > prices[i+1]:
highs.append(prices[i])
if not highs:
return False
# Check if current price is near any of these highs
current_price = self.current_price
for high in highs:
if abs(current_price - high) / high < 0.01: # Within 1% of a recent high
return True
return False
def is_market_turning(self):
"""Check if the market is potentially changing direction"""
if len(self.features['price']) < 20:
return False
# Check for divergence between price and momentum indicators
if len(self.features['rsi']) > 5:
# Price making higher highs but RSI making lower highs (bearish divergence)
price_trend = self.features['price'][-1] > self.features['price'][-5]
rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5]
if price_trend != rsi_trend:
return True
# Check for EMA crossover
if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1:
short_ema_prev = self.features['ema_9'][-2]
long_ema_prev = self.features['ema_21'][-2]
short_ema_curr = self.features['ema_9'][-1]
long_ema_curr = self.features['ema_21'][-1]
# Check if EMAs just crossed
if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \
(short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr):
return True
return False
def is_market_against_position(self, position_type):
"""Check if market conditions have turned against the current position"""
if position_type == 'long':
# For long positions, check if market has turned bearish
return self.is_downtrend() and not self.is_near_support()
elif position_type == 'short':
# For short positions, check if market has turned bullish
return self.is_uptrend() and not self.is_near_resistance()
return False
def is_near_optimal_exit(self, position_type):
"""Check if current price is near an optimal exit point for the position"""
current_idx = len(self.features['price']) - 1
if position_type == 'long' and hasattr(self, 'optimal_tops'):
# For long positions, optimal exit is near tops
for top_idx in self.optimal_tops:
if abs(current_idx - top_idx) < 3: # Within 3 candles of a top
return True
elif position_type == 'short' and hasattr(self, 'optimal_bottoms'):
# For short positions, optimal exit is near bottoms
for bottom_idx in self.optimal_bottoms:
if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom
return True
return False
def calculate_future_profit_potential(self, position_type, lookahead=20):
"""
Calculate potential profit if position is held for a certain period
This is used for retrospective backtesting rewards
Args:
position_type: 'long' or 'short'
lookahead: Number of candles to look ahead
Returns:
Potential profit percentage
"""
if len(self.data) <= 1:
return 0
# Get current price
current_price = self.current_price
# Get future prices (if available in historical data)
future_prices = []
current_idx = self.current_step
for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
if current_idx + i < len(self.data):
future_prices.append(self.data[current_idx + i]['close'])
if not future_prices:
return 0
# Calculate potential profit
if position_type == 'long':
# For long positions, find the maximum price in the future
max_future_price = max(future_prices)
potential_profit = (max_future_price - current_price) / current_price * 100
else: # short
# For short positions, find the minimum price in the future
min_future_price = min(future_prices)
potential_profit = (current_price - min_future_price) / current_price * 100
return potential_profit
async def initialize_futures(self, exchange):
"""Initialize futures trading parameters"""
if not self.demo:
try:
# Set up futures trading parameters
await exchange.set_position_mode(True) # Hedge mode
await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
logger.info(f"Futures initialized with {self.leverage}x leverage")
except Exception as e:
logger.error(f"Failed to initialize futures: {e}")
raise
async def execute_real_trade(self, exchange, action, current_price):
"""Execute real futures trade on MEXC"""
try:
position_size = self.calculate_position_size()
if action == 1: # Open long
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='buy',
amount=position_size,
params={'positionSide': 'LONG'}
)
logger.info(f"Opened LONG position: {order}")
elif action == 2: # Open short
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='sell',
amount=position_size,
params={'positionSide': 'SHORT'}
)
logger.info(f"Opened SHORT position: {order}")
elif action == 3: # Close position
position_side = 'LONG' if self.position == 'long' else 'SHORT'
order = await exchange.create_order(
symbol=self.futures_symbol,
type='market',
side='sell' if position_side == 'LONG' else 'buy',
amount=self.position_size,
params={'positionSide': position_side}
)
logger.info(f"Closed {position_side} position: {order}")
return order
except Exception as e:
logger.error(f"Trade execution failed: {e}")
return None
# Ensure GPU usage if available
def get_device():
"""Get the best available device (CUDA GPU or CPU)"""
if torch.cuda.is_available():
device = torch.device("cuda")
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Set up for mixed precision training
torch.backends.cudnn.benchmark = True
else:
device = torch.device("cpu")
logger.info("GPU not available, using CPU")
return device
# DQN agent with GPU support and mixed-precision training
class Agent:
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
device=None):
if device is None:
self.device = get_device()
else:
self.device = device
self.state_size = state_size
self.action_size = action_size
self.memory = ReplayMemory(MEMORY_SIZE)
self.steps_done = 0
self.epsilon = EPSILON_START # Initialize epsilon
# Initialize policy and target networks
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()
# Initialize optimizer with weight decay for regularization
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# Initialize gradient scaler for mixed precision training
self.scaler = amp.GradScaler()
# TensorBoard writer
self.writer = SummaryWriter()
# For chart visualization
self.chart_step = 0
# Create models directory if it doesn't exist
os.makedirs("models", exist_ok=True)
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
"""Expand the model to handle more features or increase capacity"""
logger.info(f"Expanding model: {self.state_size}{new_state_size}, "
f"hidden: {self.policy_net.hidden_size}{new_hidden_size}")
# Save old weights
old_state_dict = self.policy_net.state_dict()
# Create new larger networks
new_policy_net = DQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
new_target_net = DQN(new_state_size, self.action_size,
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
# Transfer weights for common layers
new_state_dict = new_policy_net.state_dict()
for name, param in old_state_dict.items():
if name in new_state_dict:
# If shapes match, copy directly
if new_state_dict[name].shape == param.shape:
new_state_dict[name] = param
# For first layer, copy weights for the original input dimensions
elif name == "fc1.weight":
new_state_dict[name][:, :self.state_size] = param
# For other layers, initialize with a strategy that preserves scale
else:
logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}")
# Load transferred weights
new_policy_net.load_state_dict(new_state_dict)
new_target_net.load_state_dict(new_state_dict)
# Replace networks
self.policy_net = new_policy_net
self.target_net = new_target_net
self.target_net.eval()
# Update optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
# Update state size
self.state_size = new_state_size
# Print new model size
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"New model size: {total_params:,} parameters")
return True
def select_action(self, state, training=True):
sample = random.random()
if training:
# Epsilon decay
self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
np.exp(-1. * self.steps_done / EPSILON_DECAY)
self.steps_done += 1
if sample > self.epsilon or not training:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
action_values = self.policy_net(state_tensor)
return action_values.max(1)[1].item()
else:
return random.randrange(self.action_size)
def learn(self):
"""Learn from a batch of experiences"""
if len(self.memory) < BATCH_SIZE:
return None
try:
# Sample a batch of experiences
experiences = self.memory.sample(BATCH_SIZE)
# Convert experiences to tensors
states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
dones = torch.FloatTensor([e.done for e in experiences]).to(self.device)
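# Bellman target: r + GAMMA * max_a' Q_target(s', a') * (1 - done); only the policy network receives gradients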
# Use mixed precision for forward/backward passes
if self.device.type == "cuda":
with amp.autocast():
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values with target network
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
# Reshape target values to match current_q_values
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass with mixed precision
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
# Gradient clipping to prevent exploding gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# Standard precision for CPU
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
# Compute next Q values with target network
with torch.no_grad():
next_q_values = self.target_net(next_states).max(1)[0]
target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))
# Reshape target values to match current_q_values
target_q_values = target_q_values.unsqueeze(1)
# Compute loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
# Update steps done
self.steps_done += 1
# Update target network
if self.steps_done % TARGET_UPDATE == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
return loss.item()
except Exception as e:
logger.error(f"Error during learning: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
return None
def update_target_network(self):
self.target_net.load_state_dict(self.policy_net.state_dict())
def save(self, path="models/trading_agent.pt"):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'steps_done': self.steps_done
}, path)
logger.info(f"Model saved to {path}")
def load(self, path="models/trading_agent.pt"):
if os.path.isfile(path):
try:
# First try with weights_only=True (safer)
checkpoint = torch.load(path, weights_only=True)
except Exception as e:
logger.warning(f"Failed to load with weights_only=True: {e}")
try:
# Try with safe_globals for numpy.scalar
import numpy as np
from torch.serialization import safe_globals
with safe_globals([np.core.multiarray.scalar]):
checkpoint = torch.load(path, weights_only=True)
except Exception as e2:
logger.warning(f"Failed with safe_globals: {e2}")
# Fall back to weights_only=False if needed
checkpoint = torch.load(path, weights_only=False)
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.epsilon = checkpoint['epsilon']
self.steps_done = checkpoint['steps_done']
logger.info(f"Model loaded from {path}")
return True
logger.warning(f"No model found at {path}")
return False
def add_chart_to_tensorboard(self, env, global_step):
"""Add enhanced trading chart to TensorBoard"""
if len(env.data) < 10: # Minimum data to show
return
try:
# Create chart with annotations
chart_img = create_candlestick_figure(
env.data,
env.trade_signals,
window_size=100,
title=f"Trading Chart (Step {global_step})"
)
# Add to TensorBoard
self.writer.add_image('Trading Chart', np.array(chart_img).transpose(2, 0, 1), global_step)
self.chart_step = global_step
# Also log position information
if env.position != 'flat':
position_info = {
'position_type': env.position,
'entry_price': env.entry_price,
'position_size': env.position_size,
'unrealized_pnl': env.total_pnl
}
self.writer.add_text('Position', str(position_info), global_step)
except Exception as e:
logger.error(f"Error creating chart: {e}")
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
"""Get live price data using websockets"""
# Connect to MEXC websocket
uri = "wss://stream.mexc.com/ws"
async with websockets.connect(uri) as websocket:
# Subscribe to kline data
subscribe_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
}
await websocket.send(json.dumps(subscribe_msg))
logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")
while True:
try:
response = await websocket.recv()
data = json.loads(response)
if 'data' in data:
kline = data['data']
candle = {
'timestamp': kline['t'],
'open': float(kline['o']),
'high': float(kline['h']),
'low': float(kline['l']),
'close': float(kline['c']),
'volume': float(kline['v'])
}
yield candle
except Exception as e:
logger.error(f"Websocket error: {e}")
# Try to reconnect
await asyncio.sleep(5)
break
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
"""Train the agent using historical and live data with GPU acceleration"""
logger.info(f"Starting training on device: {agent.device}")
stats = {
'episode_rewards': [],
'episode_lengths': [],
'balances': [],
'win_rates': [],
'episode_pnls': [],
'cumulative_pnl': [],
'drawdowns': [],
'prediction_accuracy': [],
'trade_analysis': []
}
best_reward = -float('inf')
best_pnl = -float('inf')
try:
# Initialize price predictor
env.initialize_price_predictor(agent.device)
for episode in range(num_episodes):
try:
# Reset environment
state = env.reset()
episode_reward = 0
env.episode_pnl = 0.0 # Reset episode PnL
# Identify optimal trade points for this episode
env.identify_optimal_trades()
# Train price predictor
prediction_loss = env.train_price_predictor()
# Update price predictions
env.update_price_predictions()
for step in range(max_steps_per_episode):
# Select action
action = agent.select_action(state)
# Take action
next_state, reward, done = env.step(action)
# Store experience
agent.memory.push(state, action, reward, next_state, done)
state = next_state
episode_reward += reward
# Learn from experience with mixed precision
try:
loss = agent.learn()
if loss is not None:
agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
except Exception as e:
logger.error(f"Learning error in episode {episode}, step {step}: {e}")
# Update price predictions periodically
if step % 10 == 0:
env.update_price_predictions()
# Add chart to TensorBoard periodically
if step % 50 == 0 or (step == max_steps_per_episode - 1) or done:
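                        # Use a global step index so charts from different episodes do not overwrite each other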
global_step = episode * max_steps_per_episode + step
agent.add_chart_to_tensorboard(env, global_step)
if done:
break
# Update target network
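                # Copying policy_net weights into target_net every TARGET_UPDATE episodes keeps the Q-learning targets stable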
if episode % TARGET_UPDATE == 0:
agent.target_net.load_state_dict(agent.policy_net.state_dict())
# Calculate win rate
if len(env.trades) > 0:
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
win_rate = wins / len(env.trades) * 100
else:
win_rate = 0
# Analyze trades
trade_analysis = env.analyze_trades()
stats['trade_analysis'].append(trade_analysis)
# Calculate prediction accuracy
prediction_accuracy = 0.0
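                # Accuracy here is 100 * (1 - mean absolute percentage error) over the most recent candles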
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
if len(env.data) > 5:
actual_prices = [candle['close'] for candle in env.data[-5:]]
predicted = env.predicted_prices[:min(5, len(actual_prices))]
errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
# Log statistics
stats['episode_rewards'].append(episode_reward)
stats['episode_lengths'].append(step + 1)
stats['balances'].append(env.balance)
stats['win_rates'].append(win_rate)
stats['episode_pnls'].append(env.episode_pnl)
stats['cumulative_pnl'].append(env.total_pnl)
stats['drawdowns'].append(env.max_drawdown * 100)
stats['prediction_accuracy'].append(prediction_accuracy)
# Log detailed trade analysis
if trade_analysis:
logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")
# Log to TensorBoard
agent.writer.add_scalar('Reward/train', episode_reward, episode)
agent.writer.add_scalar('Balance/train', env.balance, episode)
agent.writer.add_scalar('WinRate/train', win_rate, episode)
agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)
agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
# Add final chart for this episode
agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")
# Save best model by reward
if episode_reward > best_reward:
best_reward = episode_reward
agent.save("models/trading_agent_best_reward.pt")
# Save best model by PnL
if env.episode_pnl > best_pnl:
best_pnl = env.episode_pnl
agent.save("models/trading_agent_best_pnl.pt")
# Save checkpoint
if episode % 10 == 0:
agent.save(f"models/trading_agent_episode_{episode}.pt")
except Exception as e:
logger.error(f"Error in episode {episode}: {e}")
continue
# Save final model
agent.save("models/trading_agent_final.pt")
# Plot training results
plot_training_results(stats)
except Exception as e:
logger.error(f"Training failed: {e}")
raise
return stats
def plot_training_results(stats):
"""Plot detailed training results"""
plt.figure(figsize=(20, 15))
# Plot rewards
plt.subplot(3, 2, 1)
plt.plot(stats['episode_rewards'])
plt.title('Episode Rewards')
plt.xlabel('Episode')
plt.ylabel('Reward')
# Plot balance
plt.subplot(3, 2, 2)
plt.plot(stats['balances'])
plt.title('Account Balance')
plt.xlabel('Episode')
plt.ylabel('Balance ($)')
# Plot win rate
plt.subplot(3, 2, 3)
plt.plot(stats['win_rates'])
plt.title('Win Rate')
plt.xlabel('Episode')
plt.ylabel('Win Rate (%)')
# Plot episode PnL
plt.subplot(3, 2, 4)
plt.plot(stats['episode_pnls'])
plt.title('Episode PnL')
plt.xlabel('Episode')
plt.ylabel('PnL ($)')
# Plot cumulative PnL
plt.subplot(3, 2, 5)
plt.plot(stats['cumulative_pnl'])
plt.title('Cumulative PnL')
plt.xlabel('Episode')
plt.ylabel('Cumulative PnL ($)')
# Plot drawdown
plt.subplot(3, 2, 6)
plt.plot(stats['drawdowns'])
plt.title('Maximum Drawdown')
plt.xlabel('Episode')
plt.ylabel('Drawdown (%)')
plt.tight_layout()
plt.savefig('training_results.png')
# Save statistics to CSV
df = pd.DataFrame(stats)
df.to_csv('training_stats.csv', index=False)
logger.info("Training statistics saved to training_stats.csv and training_results.png")
def evaluate_agent(agent, env, num_episodes=10):
"""Evaluate the agent on test data"""
total_reward = 0
total_profit = 0
total_trades = 0
winning_trades = 0
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
initial_balance = env.balance
done = False
while not done:
# Select action (no exploration)
action = agent.select_action(state, training=False)
next_state, reward, done = env.step(action)
state = next_state
episode_reward += reward
total_reward += episode_reward
total_profit += env.balance - initial_balance
# Count trades and wins
for trade in env.trades:
if 'pnl_percent' in trade:
total_trades += 1
if trade['pnl_percent'] > 0:
winning_trades += 1
# Calculate averages
avg_reward = total_reward / num_episodes
avg_profit = total_profit / num_episodes
win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0
logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
f"Win Rate={win_rate:.1f}%")
return avg_reward, avg_profit, win_rate
async def test_training():
"""Test the training process with a small number of episodes"""
logger.info("Starting training tests...")
# Initialize exchange
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True,
})
try:
# Create environment with small initial balance for testing
env = TradingEnvironment(
exchange=exchange,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=100, # Small balance for testing
demo=True # Always use demo mode for testing
)
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
# Create agent
agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)
# Run a few test episodes
test_episodes = 3
logger.info(f"Running {test_episodes} test episodes...")
for episode in range(test_episodes):
state = env.reset()
episode_reward = 0
done = False
step = 0
while not done and step < 100: # Limit steps for testing
# Select action
action = agent.select_action(state)
# Take action
next_state, reward, done = env.step(action)
# Store experience
agent.memory.push(state, action, reward, next_state, done)
# Learn
loss = agent.learn()
state = next_state
episode_reward += reward
step += 1
# Print progress
if step % 10 == 0:
logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")
logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")
# Test model saving
try:
agent.save("models/test_model.pt")
logger.info("Successfully saved model")
except Exception as e:
logger.error(f"Error saving model: {e}")
logger.info("Training tests completed successfully")
return True
except Exception as e:
logger.error(f"Training test failed: {e}")
return False
    finally:
        # Only await close() on exchanges that expose it as a coroutine (ccxt.pro)
        if hasattr(exchange, 'close') and asyncio.iscoroutinefunction(exchange.close):
            await exchange.close()
async def initialize_exchange():
"""Initialize the exchange connection"""
try:
# Try to initialize with async support first
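        # ccxt.pro provides native async support; plain ccxt is synchronous and is
        # wrapped with run_in_executor in fetch_ohlcv_data below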
try:
exchange = ccxt.pro.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with async support: {exchange.id}")
except (AttributeError, ImportError):
# Fall back to standard CCXT
exchange = ccxt.mexc({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
return exchange
except Exception as e:
logger.error(f"Failed to initialize exchange: {e}")
raise
async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
"""Fetch historical OHLCV data from the exchange"""
try:
logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")
# Use the refactored fetch method
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
if not data:
logger.warning("No historical data received")
return data
except Exception as e:
logger.error(f"Failed to fetch historical data: {e}")
return []
async def live_trading(agent, env, exchange, demo=True):
"""Run live trading with the trained agent"""
logger.info(f"Starting live trading (demo mode: {demo})")
try:
# Subscribe to websocket for real-time data
symbol = "ETH/USDT"
timeframe = "1m"
# Initialize with historical data
success = await env.fetch_initial_data(exchange, symbol, timeframe, 100)
if not success:
logger.error("Failed to initialize with historical data")
return
# Main trading loop
step_counter = 0
# For online learning
states = []
actions = []
rewards = []
next_states = []
dones = []
while True:
            # Poll every 5 seconds for fresh data (candles close once per minute)
            await asyncio.sleep(5)
# Fetch latest candle
latest_candle = await get_latest_candle(exchange, symbol)
if not latest_candle:
logger.warning("No latest candle received, skipping update")
continue
# Update environment with new data
env.add_data(latest_candle)
# Get current state
state = env.get_state()
# Select action (no exploration in live trading)
action = agent.select_action(state, training=False)
# Take action
next_state, reward, done = env.step(action)
# Store experience for online learning
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(done)
# Online learning - update the model with new experiences
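            # Transitions are buffered in small batches before being pushed to replay memory,
            # so live model updates happen in bursts rather than on every tick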
if len(states) >= 10: # Batch size for online learning
# Store experiences in replay memory
for i in range(len(states)):
agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
# Learn from experiences if we have enough samples
if len(agent.memory) > 32:
loss = agent.learn()
if loss is not None:
agent.writer.add_scalar('Live/Loss', loss, step_counter)
# Clear the temporary storage
states = []
actions = []
rewards = []
next_states = []
dones = []
# Save the updated model periodically
if step_counter % 100 == 0:
agent.save("models/trading_agent_live_updated.pt")
logger.info("Updated model saved during live trading")
# Log trading activity
action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
logger.info(f"Price: ${latest_candle['close']:.2f} | Action: {action_names[action]}")
# Log performance metrics
if env.trades:
wins = sum(1 for t in env.trades if t.get('pnl_percent', 0) > 0)
win_rate = wins / len(env.trades) * 100
total_pnl = sum(t.get('pnl_dollar', 0) for t in env.trades)
logger.info(f"Balance: ${env.balance:.2f} | Trades: {len(env.trades)} | "
f"Win Rate: {win_rate:.1f}% | Total PnL: ${total_pnl:.2f}")
# Analyze recent trades
trade_analysis = env.analyze_trades()
if trade_analysis:
logger.info(f"Recent Performance: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends")
# Add chart to TensorBoard periodically
step_counter += 1
if step_counter % 10 == 0: # Update chart every 10 steps
agent.add_chart_to_tensorboard(env, step_counter)
# Also log current PnL and balance
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
agent.writer.add_scalar('Live/TotalPnL', env.total_pnl, step_counter)
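                # Win rate is computed defensively to avoid division by zero before any trade closes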
agent.writer.add_scalar('Live/WinRate',
(env.win_count / (env.win_count + env.loss_count) * 100)
if (env.win_count + env.loss_count) > 0 else 0,
step_counter)
except KeyboardInterrupt:
logger.info("Live trading stopped by user")
except Exception as e:
logger.error(f"Error in live trading: {e}")
raise
async def get_latest_candle(exchange, symbol):
"""Get the latest candle data"""
try:
# Use the refactored fetch method with limit=1
data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)
if data and len(data) > 0:
return data[0]
else:
logger.warning("No candle data received")
return None
except Exception as e:
logger.error(f"Failed to fetch latest candle: {e}")
return None
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
"""Fetch OHLCV data with proper handling for both async and standard CCXT"""
try:
        # Check that the exchange supports OHLCV fetching (ccxt exposes fetch_ohlcv in Python)
        if not hasattr(exchange, 'fetch_ohlcv'):
            logger.error("Exchange does not support OHLCV data fetching")
            return []
        # ccxt.pro exchanges expose fetch_ohlcv as a coroutine; standard ccxt is synchronous
        if asyncio.iscoroutinefunction(exchange.fetch_ohlcv):
            # Await the async method directly
            ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        else:
            # Run the blocking call in a thread pool so the event loop is not stalled
            loop = asyncio.get_running_loop()
            ohlcv = await loop.run_in_executor(
                None,
                lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
            )
# Convert to list of dictionaries
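        # Each raw ccxt candle is a list: [timestamp, open, high, low, close, volume]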
data = []
for candle in ohlcv:
timestamp, open_price, high, low, close, volume = candle
data.append({
'timestamp': timestamp,
'open': open_price,
'high': high,
'low': low,
'close': close,
'volume': volume
})
logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
return data
except Exception as e:
logger.error(f"Error fetching OHLCV data: {e}")
return []
async def main():
"""Main function to run the trading bot"""
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
help='Mode to run the bot in')
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
args = parser.parse_args()
# Get device (GPU or CPU)
device = get_device()
exchange = None
try:
# Initialize exchange
exchange = await initialize_exchange()
# Create environment
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
# Create agent
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
if args.mode == 'train':
# Train the agent
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'evaluate':
# Load trained model
agent.load("models/trading_agent_best_pnl.pt")
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
elif args.mode == 'live':
# Load trained model
agent.load("models/trading_agent_best_pnl.pt")
# Run live trading
logger.info("Starting live trading...")
await live_trading(agent, env, exchange, demo=args.demo)
finally:
# Clean up exchange connection - safely close if possible
if exchange:
try:
# Some CCXT exchanges have close method, others don't
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")