import os
import time
import json
import numpy as np
import pandas as pd
import datetime
import random
import logging
import asyncio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from dotenv import load_dotenv
import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp  # Mixed-precision training utilities
from sklearn.preprocessing import MinMaxScaler
import copy
import argparse
import traceback
import io
import matplotlib.dates as mdates
from matplotlib.figure import Figure
from PIL import Image

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
)
logger = logging.getLogger("trading_bot")

# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')

# Constants
INITIAL_BALANCE = 100  # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5  # Very tight stop loss (0.5%) because of the high leverage
TAKE_PROFIT_PERCENT = 1.5  # Take profit at 1.5%
MEMORY_SIZE = 100000
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 40  # Nominal state size; get_state() can emit more features (see expand_model_with_new_features)
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10  # Target-network sync interval (applied per learning step in Agent.learn)

# Experience replay tuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
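# Why the stop loss is so tight -- rough leverage arithmetic (illustrative only;
# the environment below applies pnl_percent directly to position_size, so this is
# generic futures math rather than the bot's exact accounting):
#
#   notional     = margin * MAX_LEVERAGE    # e.g. 5 USD margin -> 500 USD exposure
#   loss_at_stop = notional * 0.005         # 0.5% adverse move -> 2.50 USD, ~50% of the margin
#   loss_at_1pct = notional * 0.01          # a 1% adverse move would exhaust the margin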
|
|
|
|
def find_local_extrema(prices, window=5):
    """Find local minima (bottoms) and maxima (tops) in price data."""
    bottoms = []
    tops = []

    if len(prices) < window * 2 + 1:
        return bottoms, tops

    for i in range(window, len(prices) - window):
        # Local minimum (bottom): not higher than any neighbour within the window
        if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] <= prices[i+j] for j in range(1, window+1)):
            bottoms.append(i)

        # Local maximum (top): not lower than any neighbour within the window
        if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] >= prices[i+j] for j in range(1, window+1)):
            tops.append(i)

    return bottoms, tops
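# Illustrative usage of find_local_extrema on a tiny synthetic series: with
# window=2, index 2 (price 1.0) is a bottom and index 6 (price 5.0) is a top.
def _example_find_local_extrema():
    prices = [3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]
    bottoms, tops = find_local_extrema(prices, window=2)
    assert bottoms == [2] and tops == [6]
    return bottoms, tops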
|
|
|
|
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
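# Illustrative usage of the replay buffer (dummy transitions, not real trades).
def _example_replay_memory():
    memory = ReplayMemory(1000)
    for _ in range(10):
        zeros = np.zeros(STATE_SIZE, dtype=np.float32)
        memory.push(zeros, 0, 0.0, zeros, False)
    batch = memory.sample(4)  # list of Experience namedtuples
    return len(memory), batch[0].reward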
|
|
|
|
class DQN(nn.Module):
|
|
def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
|
|
super(DQN, self).__init__()
|
|
|
|
self.state_size = state_size
|
|
self.hidden_size = hidden_size
|
|
self.lstm_layers = lstm_layers
|
|
|
|
# Initial feature extraction
|
|
self.fc1 = nn.Linear(state_size, hidden_size)
|
|
# Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
|
|
self.ln1 = nn.LayerNorm(hidden_size)
|
|
self.dropout1 = nn.Dropout(0.2)
|
|
|
|
# LSTM layer for sequential data
|
|
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)
|
|
|
|
# Attention mechanism (defined for completeness but not applied in forward();
# the transformer encoder below provides the attention pathway)
self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
|
|
|
|
# Output layers with increased capacity
|
|
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
|
self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm
|
|
self.dropout2 = nn.Dropout(0.2)
|
|
self.fc3 = nn.Linear(hidden_size, hidden_size // 2)
|
|
|
|
# Dueling DQN architecture
|
|
self.value_stream = nn.Linear(hidden_size // 2, 1)
|
|
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
|
|
|
|
# Transformer encoder for more complex pattern recognition
|
|
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
|
|
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
|
|
|
def forward(self, x):
|
|
batch_size = x.size(0) if x.dim() > 1 else 1
|
|
|
|
# Ensure input has correct shape
|
|
if x.dim() == 1:
|
|
x = x.unsqueeze(0) # Add batch dimension
|
|
|
|
# Check if state size matches expected input size
|
|
if x.size(1) != self.state_size:
|
|
# Handle mismatched input by either truncating or padding
|
|
if x.size(1) > self.state_size:
|
|
x = x[:, :self.state_size] # Truncate
|
|
else:
|
|
# Pad with zeros
|
|
padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
|
|
x = torch.cat([x, padding], dim=1)
|
|
|
|
# Initial feature extraction
|
|
x = self.fc1(x)
|
|
x = F.relu(self.ln1(x)) # LayerNorm works with any batch size
|
|
x = self.dropout1(x)
|
|
|
|
# Reshape for LSTM
|
|
x_lstm = x.unsqueeze(1) if x.dim() == 2 else x
|
|
|
|
# Process through LSTM
|
|
lstm_out, _ = self.lstm(x_lstm)
|
|
lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]
|
|
|
|
# Process through transformer for more complex patterns
|
|
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
|
|
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
|
|
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
|
|
|
|
# Combine LSTM and transformer outputs
|
|
x = lstm_out + transformer_out
|
|
|
|
# Final layers
|
|
x = self.fc2(x)
|
|
x = F.relu(self.ln2(x)) # LayerNorm works with any batch size
|
|
x = self.dropout2(x)
|
|
x = F.relu(self.fc3(x))
|
|
|
|
# Dueling architecture
|
|
value = self.value_stream(x)
|
|
advantages = self.advantage_stream(x)
|
|
qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))
|
|
|
|
return qvals
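# Quick shape check for the dueling network above (illustrative; uses the default
# constructor arguments and the 4-action space handled by the environment).
def _check_dqn_shapes():
    net = DQN(state_size=64, action_size=4)
    net.eval()  # disable dropout for a deterministic pass
    with torch.no_grad():
        q_batch = net(torch.randn(8, 64))   # batched input  -> (8, 4)
        q_single = net(torch.randn(64))     # 1-D input gets a batch dim -> (1, 4)
    assert q_batch.shape == (8, 4) and q_single.shape == (1, 4)
    return q_batch.shape, q_single.shape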
|
|
|
|
class PricePredictionModel(nn.Module):
|
|
def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
|
|
super(PricePredictionModel, self).__init__()
|
|
self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
|
|
self.fc = nn.Linear(hidden_size, output_size)
|
|
self.scaler = MinMaxScaler(feature_range=(0, 1))
|
|
self.is_fitted = False
|
|
|
|
def forward(self, x):
|
|
# x shape: [batch_size, seq_len, 1]
|
|
lstm_out, _ = self.lstm(x)
|
|
# Use the last time step output
|
|
predictions = self.fc(lstm_out[:, -1, :])
|
|
return predictions
|
|
|
|
def preprocess(self, data):
|
|
# Reshape data for scaler
|
|
data_reshaped = np.array(data).reshape(-1, 1)
|
|
|
|
# Fit scaler if not already fitted
|
|
if not self.is_fitted:
|
|
self.scaler.fit(data_reshaped)
|
|
self.is_fitted = True
|
|
|
|
# Transform data
|
|
scaled_data = self.scaler.transform(data_reshaped)
|
|
return scaled_data
|
|
|
|
def postprocess(self, scaled_predictions):
|
|
# Inverse transform to get actual price values
|
|
return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()
|
|
|
|
def predict_next_candles(self, price_history, num_candles=5):
# Note: the prediction horizon is fixed by the model's output_size (5);
# num_candles only sizes the zero-filled fallback returned below.
if len(price_history) < 30:  # Need enough history
return np.zeros(num_candles)
|
|
|
|
# Preprocess data
|
|
scaled_data = self.preprocess(price_history)
|
|
|
|
# Create sequence
|
|
sequence = scaled_data[-30:].reshape(1, 30, 1)
|
|
sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)
|
|
|
|
# Get predictions
|
|
with torch.no_grad():
|
|
scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
|
|
|
|
# Postprocess predictions
|
|
predictions = self.postprocess(scaled_predictions)
|
|
return predictions
|
|
|
|
def train_on_new_data(self, price_history, optimizer, epochs=10):
|
|
if len(price_history) < 35: # Need enough history for training
|
|
return 0.0
|
|
|
|
# Preprocess data
|
|
scaled_data = self.preprocess(price_history)
|
|
|
|
# Create sequences and targets
|
|
sequences = []
|
|
targets = []
|
|
|
|
for i in range(len(scaled_data) - 35):
|
|
# Sequence: 30 time steps
|
|
seq = scaled_data[i:i+30]
|
|
# Target: next 5 time steps
|
|
target = scaled_data[i+30:i+35].flatten()
|
|
|
|
sequences.append(seq)
|
|
targets.append(target)
|
|
|
|
if not sequences: # If no sequences were created
|
|
return 0.0
|
|
|
|
# Convert to tensors
|
|
sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device)
|
|
targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device)
|
|
|
|
# Training loop
|
|
total_loss = 0
|
|
for _ in range(epochs):
|
|
# Forward pass
|
|
predictions = self(sequences_tensor)
|
|
|
|
# Calculate loss
|
|
loss = F.mse_loss(predictions, targets_tensor)
|
|
|
|
# Backward pass and optimize
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
total_loss += loss.item()
|
|
|
|
return total_loss / epochs
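# Illustrative end-to-end use of the price predictor on a synthetic series
# (CPU assumed; move the model to a device first if training on GPU).
def _example_price_predictor():
    model = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    prices = 2000 + 10 * np.sin(np.linspace(0, 12, 200))  # fake 1-minute closes
    loss = model.train_on_new_data(prices, optimizer, epochs=2)
    next_closes = model.predict_next_candles(prices)      # 5 predicted closes
    return loss, next_closes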
|
|
|
|
class TradingEnvironment:
|
|
def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
|
|
"""Initialize the trading environment"""
|
|
self.initial_balance = initial_balance
|
|
self.balance = initial_balance
|
|
self.window_size = window_size
|
|
self.demo = demo
|
|
self.data = []
|
|
self.position = 'flat' # 'flat', 'long', or 'short'
|
|
self.position_size = 0
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
self.trades = []
|
|
self.win_count = 0
|
|
self.loss_count = 0
|
|
self.total_pnl = 0.0
|
|
self.episode_pnl = 0.0
|
|
self.peak_balance = initial_balance
|
|
self.max_drawdown = 0.0
|
|
self.current_step = 0
|
|
self.current_price = 0
|
|
|
|
# For tracking signals for visualization
|
|
self.trade_signals = []
|
|
|
|
# Initialize features
|
|
self.features = {
|
|
'price': [],
|
|
'volume': [],
|
|
'rsi': [],
|
|
'macd': [],
|
|
'macd_signal': [],
|
|
'macd_hist': [],
|
|
'bollinger_upper': [],
|
|
'bollinger_mid': [],
|
|
'bollinger_lower': [],
|
|
'stoch_k': [],
|
|
'stoch_d': [],
|
|
'ema_9': [],
|
|
'ema_21': [],
|
|
'atr': []
|
|
}
|
|
|
|
# Initialize price predictor
|
|
self.price_predictor = None
|
|
self.predicted_prices = np.array([])
|
|
|
|
# Initialize optimal trade tracking
|
|
self.optimal_bottoms = []
|
|
self.optimal_tops = []
|
|
self.optimal_signals = np.array([])
|
|
|
|
# Add these new attributes
|
|
self.leverage = MAX_LEVERAGE
|
|
self.futures_symbol = "ETH_USDT" # Example futures symbol
|
|
self.position_mode = "hedge" # For simultaneous long/short positions
|
|
self.margin_mode = "cross" # Cross margin mode
|
|
|
|
def reset(self):
|
|
"""Reset the environment to initial state"""
|
|
self.balance = self.initial_balance
|
|
self.position = 'flat'
|
|
self.position_size = 0
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
self.trades = []
|
|
self.win_count = 0
|
|
self.loss_count = 0
|
|
self.episode_pnl = 0.0
|
|
self.peak_balance = self.initial_balance
|
|
self.max_drawdown = 0.0
|
|
self.current_step = 0
|
|
|
|
# Keep data but reset current position
|
|
if len(self.data) > self.window_size:
|
|
self.current_step = self.window_size
|
|
self.current_price = self.data[self.current_step]['close']
|
|
|
|
# Reset trade signals
|
|
self.trade_signals = []
|
|
|
|
return self.get_state()
|
|
|
|
def add_data(self, candle):
|
|
"""Add a new candle to the data"""
|
|
self.data.append(candle)
|
|
self._update_features()
|
|
self.current_price = candle['close']
|
|
|
|
def _initialize_features(self):
|
|
"""Initialize technical indicators and features"""
|
|
if len(self.data) < 30:
|
|
return
|
|
|
|
# Convert data to pandas DataFrame for easier calculation
|
|
df = pd.DataFrame(self.data)
|
|
|
|
# Basic price and volume
|
|
self.features['price'] = df['close'].values
|
|
self.features['volume'] = df['volume'].values
|
|
|
|
# Calculate RSI (14 periods)
|
|
delta = df['close'].diff()
|
|
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
|
|
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
|
|
rs = gain / loss
|
|
self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values
|
|
|
|
# Calculate MACD
|
|
ema12 = df['close'].ewm(span=12, adjust=False).mean()
|
|
ema26 = df['close'].ewm(span=26, adjust=False).mean()
|
|
macd = ema12 - ema26
|
|
signal = macd.ewm(span=9, adjust=False).mean()
|
|
self.features['macd'] = macd.values
|
|
self.features['macd_signal'] = signal.values
|
|
self.features['macd_hist'] = (macd - signal).values
|
|
|
|
# Calculate Bollinger Bands
|
|
sma20 = df['close'].rolling(window=20).mean()
|
|
std20 = df['close'].rolling(window=20).std()
|
|
self.features['bollinger_upper'] = (sma20 + 2 * std20).values
|
|
self.features['bollinger_mid'] = sma20.values
|
|
self.features['bollinger_lower'] = (sma20 - 2 * std20).values
|
|
|
|
# Calculate Stochastic Oscillator
|
|
low_14 = df['low'].rolling(window=14).min()
|
|
high_14 = df['high'].rolling(window=14).max()
|
|
k = 100 * ((df['close'] - low_14) / (high_14 - low_14))
|
|
self.features['stoch_k'] = k.values
|
|
self.features['stoch_d'] = k.rolling(window=3).mean().values
|
|
|
|
# Calculate EMAs
|
|
self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values
|
|
self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values
|
|
|
|
# Calculate ATR
|
|
high_low = df['high'] - df['low']
|
|
high_close = (df['high'] - df['close'].shift()).abs()
|
|
low_close = (df['low'] - df['close'].shift()).abs()
|
|
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
|
self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values
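# Note: the RSI above uses simple rolling means of gains/losses rather than
# Wilder's smoothing -- a common simplification that yields slightly different
# values but stays within 0-100. A quick sanity check (illustrative):
#
#   env = TradingEnvironment()
#   for i, price in enumerate(2000 + 10 * np.sin(np.arange(60) / 5.0)):
#       env.add_data({'timestamp': i * 60_000, 'open': price, 'high': price + 2,
#                     'low': price - 2, 'close': price, 'volume': 100.0})
#   assert 0 <= env.features['rsi'][-1] <= 100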
|
|
|
|
def _update_features(self):
|
|
"""Update technical indicators with new data"""
|
|
self._initialize_features() # Recalculate all features
|
|
|
|
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
|
|
"""Fetch initial historical data for the environment"""
|
|
try:
|
|
logger.info(f"Fetching initial data for {symbol}")
|
|
|
|
# Use the refactored fetch method
|
|
data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)
|
|
|
|
# Update environment with fetched data
|
|
if data:
|
|
self.data = data
|
|
self._initialize_features()
|
|
logger.info(f"Initialized environment with {len(data)} candles")
|
|
else:
|
|
logger.warning("No initial data received")
|
|
|
|
return len(data) > 0
|
|
except Exception as e:
|
|
logger.error(f"Error fetching initial data: {e}")
|
|
return False
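# fetch_ohlcv_data() is not defined in this excerpt. Based on how candles are
# indexed elsewhere (candle['close'], ['high'], ['low'], ['volume'], ['timestamp']),
# it presumably wraps ccxt's fetch_ohlcv and returns dicts, roughly (assumption):
#
#   async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
#       ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
#       return [{'timestamp': ts, 'open': o, 'high': h, 'low': l, 'close': c, 'volume': v}
#               for ts, o, h, l, c, v in ohlcv]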
|
|
|
|
def step(self, action):
|
|
"""Take an action in the environment and return the next state, reward, and done flag"""
|
|
# Store current price before taking action
|
|
self.current_price = self.data[self.current_step]['close']
|
|
|
|
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
|
|
reward = self.calculate_reward(action)
|
|
|
|
# Record trade signal for visualization
|
|
if action > 0: # If not HOLD
|
|
signal_type = None
|
|
if action == 1: # BUY/LONG
|
|
signal_type = 'buy'
|
|
elif action == 2: # SELL/SHORT
|
|
signal_type = 'sell'
|
|
elif action == 3: # CLOSE
|
|
if self.position == 'long':
|
|
signal_type = 'close_long'
|
|
elif self.position == 'short':
|
|
signal_type = 'close_short'
|
|
|
|
if signal_type:
|
|
self.trade_signals.append({
|
|
'timestamp': self.data[self.current_step]['timestamp'],
|
|
'price': self.current_price,
|
|
'type': signal_type,
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Check for stop loss / take profit hits
|
|
self.check_sl_tp()
|
|
|
|
# Move to next step
|
|
self.current_step += 1
|
|
done = self.current_step >= len(self.data) - 1
|
|
|
|
# Get new state
|
|
next_state = self.get_state()
|
|
|
|
return next_state, reward, done
|
|
|
|
def check_sl_tp(self):
|
|
"""Check if stop loss or take profit has been hit"""
|
|
if self.position == 'flat':
|
|
return
|
|
|
|
if self.position == 'long':
|
|
# Check stop loss
|
|
if self.current_price <= self.stop_loss:
|
|
# Stop loss hit
|
|
pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.stop_loss,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': self.current_step - self.entry_index,
|
|
'market_direction': self.get_market_direction(),
|
|
'reason': 'stop_loss'
|
|
})
|
|
|
|
# Update win/loss count
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Record signal for visualization
|
|
self.trade_signals.append({
|
|
'timestamp': self.data[self.current_step]['timestamp'],
|
|
'price': self.stop_loss,
|
|
'type': 'stop_loss_long',
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
# Check take profit
|
|
elif self.current_price >= self.take_profit:
|
|
# Take profit hit
|
|
pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.take_profit,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': self.current_step - self.entry_index,
|
|
'market_direction': self.get_market_direction(),
|
|
'reason': 'take_profit'
|
|
})
|
|
|
|
# Update win/loss count
|
|
self.win_count += 1
|
|
|
|
logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Record signal for visualization
|
|
self.trade_signals.append({
|
|
'timestamp': self.data[self.current_step]['timestamp'],
|
|
'price': self.take_profit,
|
|
'type': 'take_profit_long',
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
elif self.position == 'short':
|
|
# Check stop loss
|
|
if self.current_price >= self.stop_loss:
|
|
# Stop loss hit
|
|
pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.stop_loss,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': self.current_step - self.entry_index,
|
|
'market_direction': self.get_market_direction(),
|
|
'reason': 'stop_loss'
|
|
})
|
|
|
|
# Update win/loss count
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Record signal for visualization
|
|
self.trade_signals.append({
|
|
'timestamp': self.data[self.current_step]['timestamp'],
|
|
'price': self.stop_loss,
|
|
'type': 'stop_loss_short',
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
# Check take profit
|
|
elif self.current_price <= self.take_profit:
|
|
# Take profit hit
|
|
pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.take_profit,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': self.current_step - self.entry_index,
|
|
'market_direction': self.get_market_direction(),
|
|
'reason': 'take_profit'
|
|
})
|
|
|
|
# Update win/loss count
|
|
self.win_count += 1
|
|
|
|
logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Record signal for visualization
|
|
self.trade_signals.append({
|
|
'timestamp': self.data[self.current_step]['timestamp'],
|
|
'price': self.take_profit,
|
|
'type': 'take_profit_short',
|
|
'balance': self.balance,
|
|
'pnl': self.total_pnl
|
|
})
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.entry_index = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
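# The four stop-loss / take-profit branches above share almost all of their
# bookkeeping. A possible consolidation (sketch only -- not called anywhere in
# this file; the existing branches remain the source of truth):
def _close_position_at(self, exit_price, reason):
    """Close the current long/short at exit_price and record the trade."""
    if self.position == 'long':
        pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
    else:  # short
        pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100
    pnl_dollar = pnl_percent / 100 * self.position_size - self.calculate_fees(self.position_size)

    # Balance, cumulative PnL and drawdown bookkeeping
    self.balance += pnl_dollar
    self.total_pnl += pnl_dollar
    self.episode_pnl += pnl_dollar
    self.peak_balance = max(self.peak_balance, self.balance)
    drawdown = (self.peak_balance - self.balance) / self.peak_balance
    self.max_drawdown = max(self.max_drawdown, drawdown)

    # Trade record and visualization signal
    self.trades.append({
        'type': self.position, 'entry': self.entry_price, 'exit': exit_price,
        'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar,
        'duration': self.current_step - self.entry_index,
        'market_direction': self.get_market_direction(), 'reason': reason,
    })
    if reason == 'take_profit':
        self.win_count += 1
    else:
        self.loss_count += 1
    self.trade_signals.append({
        'timestamp': self.data[self.current_step]['timestamp'], 'price': exit_price,
        'type': f"{reason}_{self.position}", 'balance': self.balance, 'pnl': self.total_pnl,
    })

    # Flatten the position
    self.position = 'flat'
    self.entry_price = self.entry_index = self.position_size = 0
    self.stop_loss = self.take_profit = 0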
|
|
|
|
def get_state(self):
|
|
"""Create state representation for the agent with enhanced features"""
|
|
if len(self.data) < 30 or len(self.features['price']) == 0:
|
|
# Return zeros if not enough data
|
|
return np.zeros(STATE_SIZE)
|
|
|
|
# Create a normalized state vector with recent price action and indicators
|
|
state_components = []
|
|
|
|
# Price features (normalize recent prices by the latest price)
|
|
latest_price = self.features['price'][-1]
|
|
price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
|
|
state_components.append(price_features)
|
|
|
|
# Volume features (normalize by max volume)
|
|
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
|
vol_features = np.array(self.features['volume'][-5:]) / max_vol
|
|
state_components.append(vol_features)
|
|
|
|
# Technical indicators
|
|
rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1
|
|
state_components.append(rsi)
|
|
|
|
# MACD (normalize)
|
|
macd_vals = np.array(self.features['macd'][-3:])
|
|
macd_signal = np.array(self.features['macd_signal'][-3:])
|
|
macd_hist = np.array(self.features['macd_hist'][-3:])
|
|
macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
|
|
macd_norm = macd_vals / macd_scale
|
|
macd_signal_norm = macd_signal / macd_scale
|
|
macd_hist_norm = macd_hist / macd_scale
|
|
|
|
state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm])
|
|
|
|
# Bollinger position (where is price relative to bands)
|
|
bb_upper = np.array(self.features['bollinger_upper'][-3:])
|
|
bb_lower = np.array(self.features['bollinger_lower'][-3:])
|
|
bb_mid = np.array(self.features['bollinger_mid'][-3:])
|
|
price = np.array(self.features['price'][-3:])
|
|
|
|
# Calculate position of price within Bollinger Bands (0 to 1)
|
|
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
|
|
state_components.append(np.array(bb_pos))
|
|
|
|
# Stochastic oscillator
|
|
state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0)
|
|
state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0)
|
|
|
|
# Add predicted prices (if available)
|
|
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
|
|
# Normalize predictions relative to current price
|
|
pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
|
|
state_components.append(pred_norm)
|
|
else:
|
|
# Add zeros if no predictions
|
|
state_components.append(np.zeros(3))
|
|
|
|
# Add extrema signals (if available)
|
|
if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0:
|
|
# Get recent signals
|
|
idx = len(self.optimal_signals) - 5
|
|
if idx < 0:
|
|
idx = 0
|
|
recent_signals = self.optimal_signals[idx:idx+5]
|
|
# Pad if needed
|
|
if len(recent_signals) < 5:
|
|
recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
|
|
state_components.append(recent_signals)
|
|
else:
|
|
# Add zeros if no signals
|
|
state_components.append(np.zeros(5))
|
|
|
|
# Position info
|
|
position_info = np.zeros(5)
|
|
if self.position == 'long':
|
|
position_info[0] = 1.0 # Position is long
|
|
position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL %
|
|
position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss %
|
|
position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit %
|
|
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
|
elif self.position == 'short':
|
|
position_info[0] = -1.0 # Position is short
|
|
position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL %
|
|
position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss %
|
|
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
|
|
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
|
|
|
state_components.append(position_info)
|
|
|
|
# NEW FEATURES START HERE
|
|
|
|
# 1. Price momentum features (rate of change over different periods)
|
|
if len(self.features['price']) >= 20:
|
|
roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0
|
|
roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0
|
|
roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0
|
|
momentum_features = np.array([roc_5, roc_10, roc_20])
|
|
state_components.append(momentum_features)
|
|
else:
|
|
state_components.append(np.zeros(3))
|
|
|
|
# 2. Volatility features
|
|
if len(self.features['price']) >= 20:
|
|
# Calculate price returns
|
|
returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1]
|
|
# Calculate volatility (standard deviation of returns)
|
|
volatility = np.std(returns)
|
|
# Calculate normalized high-low range
|
|
high_low_range = np.mean([
|
|
(self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close']
|
|
for i in range(max(0, len(self.data)-5), len(self.data))
|
|
]) if len(self.data) > 0 else 0
|
|
# ATR normalized by price
|
|
atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0
|
|
|
|
volatility_features = np.array([volatility, high_low_range, atr_norm])
|
|
state_components.append(volatility_features)
|
|
else:
|
|
state_components.append(np.zeros(3))
|
|
|
|
# 3. Market regime features
|
|
if len(self.features['price']) >= 50:
|
|
# Trend strength (ADX-like measure)
|
|
ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
|
|
ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
|
|
trend_strength = abs(ema9 - ema21) / ema21
|
|
|
|
# Detect if in range or trending
|
|
is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
|
|
is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0
|
|
|
|
# Detect if near support/resistance
|
|
near_support = 1.0 if self.is_near_support() else 0.0
|
|
near_resistance = 1.0 if self.is_near_resistance() else 0.0
|
|
|
|
market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance])
|
|
state_components.append(market_regime)
|
|
else:
|
|
state_components.append(np.zeros(5))
|
|
|
|
# 4. Trade history features
|
|
if len(self.trades) > 0:
|
|
# Recent win/loss ratio
|
|
recent_trades = self.trades[-min(10, len(self.trades)):]
|
|
win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades)
|
|
|
|
# Average profit/loss
|
|
avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0
|
|
avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0
|
|
|
|
# Normalize by balance
|
|
avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0
|
|
avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0
|
|
|
|
# Last trade result
|
|
last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0
|
|
|
|
trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl])
|
|
state_components.append(trade_history)
|
|
else:
|
|
state_components.append(np.zeros(4))
|
|
|
|
# Combine all features
|
|
state = np.concatenate([comp.flatten() for comp in state_components])
|
|
|
|
# Replace any NaN values
|
|
state = np.nan_to_num(state, nan=0.0)
|
|
|
|
# Return the state (the caller will handle sizing)
|
|
return state
|
|
|
|
def get_expanded_state_size(self):
|
|
"""Calculate the size of the expanded state representation"""
|
|
# Create a dummy state to get its size
|
|
state = self.get_state()
|
|
return len(state)
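# The agent's input width should come from the environment rather than the
# STATE_SIZE constant, since get_state() emits more than 40 features once data
# is loaded. A sketch of the assumed wiring (the training loop is outside this excerpt):
#
#   state_size = env.get_expanded_state_size()          # requires >= 30 candles loaded
#   agent = Agent(state_size=state_size, action_size=4)
#   await expand_model_with_new_features(agent, env)    # grows the network if needed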
|
|
|
|
async def expand_model_with_new_features(agent, env):
|
|
"""Expand the model to handle new features without retraining from scratch"""
|
|
# Get the new state size
|
|
new_state_size = env.get_expanded_state_size()
|
|
|
|
# Only expand if the new state size is larger
|
|
if new_state_size > agent.state_size:
|
|
logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")
|
|
|
|
# Expand the model
|
|
success = agent.expand_model(
|
|
new_state_size=new_state_size,
|
|
new_hidden_size=512, # Increase hidden size for more capacity
|
|
new_lstm_layers=3, # More layers for deeper patterns
|
|
new_attention_heads=8 # More attention heads for complex relationships
|
|
)
|
|
|
|
if success:
|
|
logger.info(f"Model successfully expanded to handle {new_state_size} features")
|
|
return True
|
|
else:
|
|
logger.error("Failed to expand model")
|
|
return False
|
|
else:
|
|
logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
|
|
return True
|
|
|
|
|
|
def calculate_reward(self, action):
|
|
"""Calculate reward for the given action with improved penalties for losing trades"""
|
|
reward = 0
|
|
|
|
# Base reward for actions
|
|
if action == 0: # HOLD
|
|
reward = -0.01 # Small penalty for doing nothing
|
|
|
|
elif action == 1: # BUY/LONG
|
|
if self.position == 'flat':
|
|
# Opening a long position
|
|
self.position = 'long'
|
|
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # record entry bar so trade duration is tracked
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
|
|
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
|
|
|
|
# Check if this is an optimal buy point (bottom)
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
|
|
reward += 2.0 # Bonus for buying at a bottom
|
|
else:
|
|
# Check if we're buying in a downtrend (bad)
|
|
if self.is_downtrend():
|
|
reward -= 0.5 # Penalty for buying in downtrend
|
|
else:
|
|
reward += 0.1 # Small reward for opening a position
|
|
|
|
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif self.position == 'short':
|
|
# Close short and open long
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Record trade
|
|
trade_duration = len(self.features['price']) - self.entry_index
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': trade_duration,
|
|
'market_direction': self.get_market_direction()
|
|
})
|
|
|
|
# Reward based on PnL with stronger penalties for losses
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
# Stronger penalty for losses, scaled by the size of the loss
|
|
loss_penalty = 1.0 + abs(pnl_dollar) / 5
|
|
reward -= loss_penalty
|
|
self.loss_count += 1
|
|
|
|
# Extra penalty for closing a losing trade too quickly
|
|
if trade_duration < 5:
|
|
reward -= 0.5 # Penalty for very short losing trades
|
|
|
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Now open long
|
|
self.position = 'long'
|
|
self.entry_price = self.current_price
|
|
self.entry_index = len(self.features['price']) - 1
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
|
|
self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
|
|
|
|
# Check if this is an optimal buy point
|
|
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
|
|
reward += 2.0 # Bonus for buying at a bottom
|
|
|
|
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif action == 2: # SELL/SHORT
|
|
if self.position == 'flat':
|
|
# Opening a short position
|
|
self.position = 'short'
|
|
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # record entry bar so trade duration is tracked
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
|
|
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
|
|
|
|
# Check if this is an optimal sell point (top)
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
|
|
reward += 2.0 # Bonus for selling at a top
|
|
else:
|
|
reward += 0.1 # Small reward for opening a position
|
|
|
|
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif self.position == 'long':
|
|
# Close long and open short
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Now open short
|
|
self.position = 'short'
|
|
self.entry_price = self.current_price
self.entry_index = len(self.features['price']) - 1  # record entry bar so trade duration is tracked
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
|
|
self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
|
|
|
|
# Check if this is an optimal sell point
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
|
|
reward += 2.0 # Bonus for selling at a top
|
|
|
|
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif action == 3: # CLOSE
|
|
if self.position == 'long':
|
|
# Close long position
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
elif self.position == 'short':
|
|
# Close short position
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
# Add prediction accuracy component to reward
|
|
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
|
|
# Compare the first prediction with actual price
|
|
if len(self.data) > 1:
|
|
actual_price = self.data[-1]['close']
|
|
predicted_price = self.predicted_prices[0]
|
|
prediction_error = abs(predicted_price - actual_price) / actual_price
|
|
|
|
# Reward accurate predictions, penalize bad ones
|
|
if prediction_error < 0.005: # Less than 0.5% error
|
|
reward += 0.5
|
|
elif prediction_error > 0.02: # More than 2% error
|
|
reward -= 0.5
|
|
|
|
return reward
|
|
|
|
def is_downtrend(self):
|
|
"""Check if the market is in a downtrend"""
|
|
if len(self.features['price']) < 20:
|
|
return False
|
|
|
|
# Use EMA to determine trend
|
|
short_ema = self.features['ema_9'][-1]
|
|
long_ema = self.features['ema_21'][-1]
|
|
|
|
# Downtrend if short EMA is below long EMA
|
|
return short_ema < long_ema
|
|
|
|
def is_uptrend(self):
|
|
"""Check if the market is in an uptrend"""
|
|
if len(self.features['price']) < 20:
|
|
return False
|
|
|
|
# Use EMA to determine trend
|
|
short_ema = self.features['ema_9'][-1]
|
|
long_ema = self.features['ema_21'][-1]
|
|
|
|
# Uptrend if short EMA is above long EMA
|
|
return short_ema > long_ema
|
|
|
|
def get_market_direction(self):
|
|
"""Get the current market direction"""
|
|
if self.is_uptrend():
|
|
return "uptrend"
|
|
elif self.is_downtrend():
|
|
return "downtrend"
|
|
else:
|
|
return "sideways"
|
|
|
|
def analyze_trades(self):
|
|
"""Analyze completed trades to identify patterns"""
|
|
if not self.trades:
|
|
return {}
|
|
|
|
analysis = {
|
|
'total_trades': len(self.trades),
|
|
'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0),
|
|
'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0),
|
|
'avg_win': 0,
|
|
'avg_loss': 0,
|
|
'avg_duration': 0,
|
|
'uptrend_win_rate': 0,
|
|
'downtrend_win_rate': 0,
|
|
'sideways_win_rate': 0
|
|
}
|
|
|
|
# Calculate averages
|
|
wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0]
|
|
losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0]
|
|
durations = [t.get('duration', 0) for t in self.trades]
|
|
|
|
analysis['avg_win'] = sum(wins) / len(wins) if wins else 0
|
|
analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0
|
|
analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0
|
|
|
|
# Calculate win rates by market direction
|
|
for direction in ['uptrend', 'downtrend', 'sideways']:
|
|
direction_trades = [t for t in self.trades if t.get('market_direction') == direction]
|
|
if direction_trades:
|
|
wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0)
|
|
analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100
|
|
|
|
return analysis
|
|
|
|
def initialize_price_predictor(self, device="cpu"):
|
|
"""Initialize the price prediction model"""
|
|
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
|
|
self.price_predictor.to(device)
|
|
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
|
|
self.predicted_prices = np.array([])
|
|
|
|
def train_price_predictor(self):
|
|
"""Train the price prediction model on recent data"""
|
|
if len(self.features['price']) < 35:
|
|
return 0.0
|
|
|
|
# Get price history
|
|
price_history = self.features['price']
|
|
|
|
# Train the model
|
|
loss = self.price_predictor.train_on_new_data(
|
|
price_history,
|
|
self.price_predictor_optimizer,
|
|
epochs=5
|
|
)
|
|
|
|
return loss
|
|
|
|
def update_price_predictions(self):
|
|
"""Update price predictions"""
|
|
if len(self.features['price']) < 30:
|
|
self.predicted_prices = np.array([])
|
|
return
|
|
|
|
# Get price history
|
|
price_history = self.features['price']
|
|
|
|
# Get predictions
|
|
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
|
|
|
|
def identify_optimal_trades(self):
|
|
"""Identify optimal entry and exit points based on local extrema"""
|
|
if len(self.features['price']) < 20:
|
|
return
|
|
|
|
# Find local bottoms and tops
|
|
bottoms, tops = find_local_extrema(self.features['price'], window=5)
|
|
|
|
# Store optimal trade points
|
|
self.optimal_bottoms = bottoms # Buy points
|
|
self.optimal_tops = tops # Sell points
|
|
|
|
# Create optimal trade signals
|
|
self.optimal_signals = np.zeros(len(self.features['price']))
|
|
for i in bottoms:
|
|
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
|
|
self.optimal_signals[i] = 1 # Buy signal
|
|
for i in tops:
|
|
if 0 <= i < len(self.optimal_signals): # Ensure index is valid
|
|
self.optimal_signals[i] = -1 # Sell signal
|
|
|
|
logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
|
|
|
|
def calculate_position_size(self):
|
|
"""Calculate position size based on current balance and risk parameters"""
|
|
# Use a fixed percentage of balance for each trade
|
|
risk_percent = 5.0 # Risk 5% of balance per trade
|
|
|
|
# Calculate position size with leverage
|
|
position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE
|
|
|
|
# Apply a safety factor to avoid liquidation
|
|
safety_factor = 0.8
|
|
position_size *= safety_factor
|
|
|
|
# Ensure minimum position size
|
|
min_position = 10.0 # Minimum position size in USD
|
|
position_size = max(position_size, min(min_position, self.balance * 0.5))
|
|
|
|
# Ensure position size doesn't exceed balance * leverage
|
|
max_position = self.balance * MAX_LEVERAGE
|
|
position_size = min(position_size, max_position)
|
|
|
|
return position_size
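# Worked example with the defaults (illustrative): balance = 100 USD, 5% risk,
# MAX_LEVERAGE = 100, safety factor 0.8:
#
#   raw   = 100 * 0.05 * 100       = 500   (risk share scaled by leverage)
#   sized = 500 * 0.8              = 400   (safety factor)
#   floor = max(400, min(10, 50))  = 400   (the minimum only matters for tiny balances)
#   cap   = min(400, 100 * 100)    = 400   (cannot exceed balance * leverage)
#
# so a 100 USD account opens a 400 USD notional position, and calculate_fees()
# below charges 0.1% of that notional (0.40 USD) per side.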
|
|
|
|
def calculate_fees(self, position_size):
|
|
"""Calculate trading fees for a given position size"""
|
|
# Typical fee rate for crypto exchanges (0.1%)
|
|
fee_rate = 0.001
|
|
|
|
# Calculate fee
|
|
fee = position_size * fee_rate
|
|
|
|
return fee
|
|
|
|
def is_uncertain_market(self):
|
|
"""Check if the market is in an uncertain/sideways state"""
|
|
if len(self.features['price']) < 20:
|
|
return True
|
|
|
|
# Check if price is within a narrow range
|
|
recent_prices = self.features['price'][-20:]
|
|
price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices)
|
|
|
|
# Check if EMAs are close to each other
|
|
if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
|
|
short_ema = self.features['ema_9'][-1]
|
|
long_ema = self.features['ema_21'][-1]
|
|
ema_diff = abs(short_ema - long_ema) / long_ema
|
|
|
|
# Return True if price range is small and EMAs are close
|
|
return price_range < 0.02 and ema_diff < 0.005
|
|
|
|
return price_range < 0.015 # Very narrow range
|
|
|
|
def is_near_support(self):
|
|
"""Check if current price is near a support level"""
|
|
if not hasattr(self, 'features') or len(self.features['price']) < 30:
|
|
return False
|
|
|
|
# Find recent lows
|
|
prices = self.features['price'][-30:]
|
|
lows = []
|
|
|
|
for i in range(1, len(prices)-1):
|
|
if prices[i] < prices[i-1] and prices[i] < prices[i+1]:
|
|
lows.append(prices[i])
|
|
|
|
if not lows:
|
|
return False
|
|
|
|
# Check if current price is near any of these lows
|
|
current_price = self.current_price
|
|
for low in lows:
|
|
if abs(current_price - low) / low < 0.01: # Within 1% of a recent low
|
|
return True
|
|
|
|
return False
|
|
|
|
def is_near_resistance(self):
|
|
"""Check if current price is near a resistance level"""
|
|
if not hasattr(self, 'features') or len(self.features['price']) < 30:
|
|
return False
|
|
|
|
# Find recent highs
|
|
prices = self.features['price'][-30:]
|
|
highs = []
|
|
|
|
for i in range(1, len(prices)-1):
|
|
if prices[i] > prices[i-1] and prices[i] > prices[i+1]:
|
|
highs.append(prices[i])
|
|
|
|
if not highs:
|
|
return False
|
|
|
|
# Check if current price is near any of these highs
|
|
current_price = self.current_price
|
|
for high in highs:
|
|
if abs(current_price - high) / high < 0.01: # Within 1% of a recent high
|
|
return True
|
|
|
|
return False
|
|
|
|
def is_market_turning(self):
|
|
"""Check if the market is potentially changing direction"""
|
|
if len(self.features['price']) < 20:
|
|
return False
|
|
|
|
# Check for divergence between price and momentum indicators
|
|
if len(self.features['rsi']) > 5:
|
|
# Price making higher highs but RSI making lower highs (bearish divergence)
|
|
price_trend = self.features['price'][-1] > self.features['price'][-5]
|
|
rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5]
|
|
|
|
if price_trend != rsi_trend:
|
|
return True
|
|
|
|
# Check for EMA crossover
|
|
if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1:
|
|
short_ema_prev = self.features['ema_9'][-2]
|
|
long_ema_prev = self.features['ema_21'][-2]
|
|
short_ema_curr = self.features['ema_9'][-1]
|
|
long_ema_curr = self.features['ema_21'][-1]
|
|
|
|
# Check if EMAs just crossed
|
|
if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \
|
|
(short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr):
|
|
return True
|
|
|
|
return False
|
|
|
|
def is_market_against_position(self, position_type):
|
|
"""Check if market conditions have turned against the current position"""
|
|
if position_type == 'long':
|
|
# For long positions, check if market has turned bearish
|
|
return self.is_downtrend() and not self.is_near_support()
|
|
elif position_type == 'short':
|
|
# For short positions, check if market has turned bullish
|
|
return self.is_uptrend() and not self.is_near_resistance()
|
|
|
|
return False
|
|
|
|
def is_near_optimal_exit(self, position_type):
|
|
"""Check if current price is near an optimal exit point for the position"""
|
|
current_idx = len(self.features['price']) - 1
|
|
|
|
if position_type == 'long' and hasattr(self, 'optimal_tops'):
|
|
# For long positions, optimal exit is near tops
|
|
for top_idx in self.optimal_tops:
|
|
if abs(current_idx - top_idx) < 3: # Within 3 candles of a top
|
|
return True
|
|
elif position_type == 'short' and hasattr(self, 'optimal_bottoms'):
|
|
# For short positions, optimal exit is near bottoms
|
|
for bottom_idx in self.optimal_bottoms:
|
|
if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom
|
|
return True
|
|
|
|
return False
|
|
|
|
def calculate_future_profit_potential(self, position_type, lookahead=20):
|
|
"""
|
|
Calculate potential profit if position is held for a certain period
|
|
This is used for retrospective backtesting rewards
|
|
|
|
Args:
|
|
position_type: 'long' or 'short'
|
|
lookahead: Number of candles to look ahead
|
|
|
|
Returns:
|
|
Potential profit percentage
|
|
"""
|
|
if len(self.data) <= 1:
|
|
return 0
|
|
|
|
# Get current price
|
|
current_price = self.current_price
|
|
|
|
# Get future prices (if available in historical data)
|
|
future_prices = []
|
|
current_idx = self.current_step
|
|
|
|
for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
|
|
if current_idx + i < len(self.data):
|
|
future_prices.append(self.data[current_idx + i]['close'])
|
|
|
|
if not future_prices:
|
|
return 0
|
|
|
|
# Calculate potential profit
|
|
if position_type == 'long':
|
|
# For long positions, find the maximum price in the future
|
|
max_future_price = max(future_prices)
|
|
potential_profit = (max_future_price - current_price) / current_price * 100
|
|
else: # short
|
|
# For short positions, find the minimum price in the future
|
|
min_future_price = min(future_prices)
|
|
potential_profit = (current_price - min_future_price) / current_price * 100
|
|
|
|
return potential_profit
|
|
|
|
async def initialize_futures(self, exchange):
|
|
"""Initialize futures trading parameters"""
|
|
if not self.demo:
|
|
try:
|
|
# Set up futures trading parameters
|
|
await exchange.set_position_mode(True) # Hedge mode
|
|
await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
|
|
await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
|
|
logger.info(f"Futures initialized with {self.leverage}x leverage")
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize futures: {e}")
|
|
raise
|
|
|
|
async def execute_real_trade(self, exchange, action, current_price):
|
|
"""Execute real futures trade on MEXC"""
|
|
try:
|
|
position_size = self.calculate_position_size()
|
|
|
|
if action == 1: # Open long
|
|
order = await exchange.create_order(
|
|
symbol=self.futures_symbol,
|
|
type='market',
|
|
side='buy',
|
|
amount=position_size,
|
|
params={'positionSide': 'LONG'}
|
|
)
|
|
logger.info(f"Opened LONG position: {order}")
|
|
|
|
elif action == 2: # Open short
|
|
order = await exchange.create_order(
|
|
symbol=self.futures_symbol,
|
|
type='market',
|
|
side='sell',
|
|
amount=position_size,
|
|
params={'positionSide': 'SHORT'}
|
|
)
|
|
logger.info(f"Opened SHORT position: {order}")
|
|
|
|
elif action == 3: # Close position
|
|
position_side = 'LONG' if self.position == 'long' else 'SHORT'
|
|
order = await exchange.create_order(
|
|
symbol=self.futures_symbol,
|
|
type='market',
|
|
side='sell' if position_side == 'LONG' else 'buy',
|
|
amount=self.position_size,
|
|
params={'positionSide': position_side}
|
|
)
|
|
logger.info(f"Closed {position_side} position: {order}")
|
|
|
|
return order
|
|
except Exception as e:
|
|
logger.error(f"Trade execution failed: {e}")
|
|
return None
|
|
|
|
# Ensure GPU usage if available
|
|
def get_device():
|
|
"""Get the best available device (CUDA GPU or CPU)"""
|
|
if torch.cuda.is_available():
|
|
device = torch.device("cuda")
|
|
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
|
# Set up for mixed precision training
|
|
torch.backends.cudnn.benchmark = True
|
|
else:
|
|
device = torch.device("cpu")
|
|
logger.info("GPU not available, using CPU")
|
|
return device
|
|
|
|
# Update Agent class to use GPU properly
|
|
class Agent:
|
|
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
|
device=None):
|
|
if device is None:
|
|
self.device = get_device()
|
|
else:
|
|
self.device = device
|
|
|
|
self.state_size = state_size
|
|
self.action_size = action_size
|
|
self.memory = ReplayMemory(MEMORY_SIZE)
|
|
self.steps_done = 0
|
|
self.epsilon = EPSILON_START # Initialize epsilon
|
|
|
|
# Initialize policy and target networks
|
|
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
|
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
self.target_net.eval()
|
|
|
|
# Initialize optimizer with weight decay for regularization
|
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
|
|
|
# Initialize gradient scaler for mixed precision training
|
|
self.scaler = amp.GradScaler()
|
|
|
|
# TensorBoard writer
|
|
self.writer = SummaryWriter()
|
|
|
|
# For chart visualization
|
|
self.chart_step = 0
|
|
|
|
# Create models directory if it doesn't exist
|
|
os.makedirs("models", exist_ok=True)
|
|
|
|
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
|
|
"""Expand the model to handle more features or increase capacity"""
|
|
logger.info(f"Expanding model: {self.state_size} → {new_state_size}, "
|
|
f"hidden: {self.policy_net.hidden_size} → {new_hidden_size}")
|
|
|
|
# Save old weights
|
|
old_state_dict = self.policy_net.state_dict()
|
|
|
|
# Create new larger networks
|
|
new_policy_net = DQN(new_state_size, self.action_size,
|
|
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
|
|
new_target_net = DQN(new_state_size, self.action_size,
|
|
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
|
|
|
|
# Transfer weights for common layers
|
|
new_state_dict = new_policy_net.state_dict()
|
|
for name, param in old_state_dict.items():
|
|
if name in new_state_dict:
|
|
# If shapes match, copy directly
|
|
if new_state_dict[name].shape == param.shape:
|
|
new_state_dict[name] = param
|
|
# For first layer, copy weights for the original input dimensions
|
|
elif name == "fc1.weight":
|
|
new_state_dict[name][:, :self.state_size] = param
|
|
# For other layers, initialize with a strategy that preserves scale
|
|
else:
|
|
logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}")
|
|
|
|
# Load transferred weights
|
|
new_policy_net.load_state_dict(new_state_dict)
|
|
new_target_net.load_state_dict(new_state_dict)
|
|
|
|
# Replace networks
|
|
self.policy_net = new_policy_net
|
|
self.target_net = new_target_net
|
|
self.target_net.eval()
|
|
|
|
# Update optimizer
|
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
|
|
|
# Update state size
|
|
self.state_size = new_state_size
|
|
|
|
# Print new model size
|
|
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
|
logger.info(f"New model size: {total_params:,} parameters")
|
|
|
|
return True
|
|
|
|

    def select_action(self, state, training=True):
        sample = random.random()

        if training:
            # Epsilon decay
            self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                           np.exp(-1. * self.steps_done / EPSILON_DECAY)
            self.steps_done += 1

        if sample > self.epsilon or not training:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).to(self.device)
                # Ensure a batch dimension so .max(1) works for a single state
                if state_tensor.dim() == 1:
                    state_tensor = state_tensor.unsqueeze(0)
                action_values = self.policy_net(state_tensor)
                return action_values.max(1)[1].item()
        else:
            return random.randrange(self.action_size)
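
    # Rough shape of the epsilon schedule above (illustrative values, assuming the
    # module constants EPSILON_START=1.0, EPSILON_END=0.05, EPSILON_DECAY=10000):
    #
    #   steps_done:     0      5,000   10,000   20,000   50,000
    #   epsilon:      1.00     0.63     0.40     0.18     0.06
    #
    # i.e. exploration decays exponentially toward EPSILON_END as steps accumulate.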

    def learn(self):
        """Learn from a batch of experiences"""
        if len(self.memory) < BATCH_SIZE:
            return None

        try:
            # Sample a batch of experiences
            experiences = self.memory.sample(BATCH_SIZE)

            # Convert experiences to tensors
            states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
            actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
            rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
            next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
            dones = torch.FloatTensor([e.done for e in experiences]).to(self.device)

            # Use mixed precision for forward/backward passes
            if self.device.type == "cuda":
                with amp.autocast():
                    # Compute Q values
                    current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))

                    # Compute next Q values with target network
                    with torch.no_grad():
                        next_q_values = self.target_net(next_states).max(1)[0]
                        target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))

                    # Reshape target values to match current_q_values
                    target_q_values = target_q_values.unsqueeze(1)

                    # Compute loss
                    loss = F.smooth_l1_loss(current_q_values, target_q_values)

                # Backward pass with mixed precision
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradient clipping to prevent exploding gradients
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Standard precision for CPU
                # Compute Q values
                current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))

                # Compute next Q values with target network
                with torch.no_grad():
                    next_q_values = self.target_net(next_states).max(1)[0]
                    target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))

                # Reshape target values to match current_q_values
                target_q_values = target_q_values.unsqueeze(1)

                # Compute loss
                loss = F.smooth_l1_loss(current_q_values, target_q_values)

                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()

                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)

                self.optimizer.step()

            # Update steps done
            self.steps_done += 1

            # Update target network
            if self.steps_done % TARGET_UPDATE == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            return loss.item()

        except Exception as e:
            logger.error(f"Error during learning: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None
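
    # Worked example (illustrative numbers only) of the TD target computed in learn():
    #
    #   reward = 0.5, GAMMA = 0.99, max_a' Q_target(s', a') = 2.0
    #   non-terminal (done = 0): target = 0.5 + 0.99 * 2.0 * (1 - 0) = 2.48
    #   terminal     (done = 1): target = 0.5 + 0.99 * 2.0 * (1 - 1) = 0.50
    #
    # The smooth L1 loss then pulls Q(s, a) toward this target.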

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path="models/trading_agent.pt"):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps_done': self.steps_done
        }, path)
        logger.info(f"Model saved to {path}")

    def load(self, path="models/trading_agent.pt"):
        if os.path.isfile(path):
            try:
                # First try with weights_only=True (safer)
                checkpoint = torch.load(path, weights_only=True)
            except Exception as e:
                logger.warning(f"Failed to load with weights_only=True: {e}")
                try:
                    # Try with safe_globals for numpy scalar objects
                    from torch.serialization import safe_globals
                    with safe_globals([np.core.multiarray.scalar]):
                        checkpoint = torch.load(path, weights_only=True)
                except Exception as e2:
                    logger.warning(f"Failed with safe_globals: {e2}")
                    # Fall back to weights_only=False if needed
                    checkpoint = torch.load(path, weights_only=False)

            self.policy_net.load_state_dict(checkpoint['policy_net'])
            self.target_net.load_state_dict(checkpoint['target_net'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.epsilon = checkpoint['epsilon']
            self.steps_done = checkpoint['steps_done']
            logger.info(f"Model loaded from {path}")
            return True

        logger.warning(f"No model found at {path}")
        return False

    def add_chart_to_tensorboard(self, env, global_step):
        """Add enhanced trading chart to TensorBoard"""
        if len(env.data) < 10:  # Minimum data to show
            return

        try:
            # Create chart with annotations
            chart_img = create_candlestick_figure(
                env.data,
                env.trade_signals,
                window_size=100,
                title=f"Trading Chart (Step {global_step})"
            )

            # Add to TensorBoard (convert the HWC image to CHW as add_image expects)
            self.writer.add_image('Trading Chart', np.array(chart_img).transpose(2, 0, 1), global_step)
            self.chart_step = global_step

            # Also log position information
            if env.position != 'flat':
                position_info = {
                    'position_type': env.position,
                    'entry_price': env.entry_price,
                    'position_size': env.position_size,
                    'unrealized_pnl': env.total_pnl
                }
                self.writer.add_text('Position', str(position_info), global_step)

        except Exception as e:
            logger.error(f"Error creating chart: {e}")
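
# Illustrative sketch (assumption, never called by the bot): the intended
# checkpoint round-trip using the Agent save/load API defined above.
def _example_checkpoint_roundtrip():
    agent = Agent(state_size=STATE_SIZE, action_size=4)
    agent.save("models/example_agent.pt")      # writes policy/target/optimizer state plus epsilon
    restored = Agent(state_size=STATE_SIZE, action_size=4)
    restored.load("models/example_agent.pt")   # also resumes epsilon and steps_done
    return restored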

async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
    """Get live price data using websockets"""
    # Connect to MEXC websocket
    uri = "wss://stream.mexc.com/ws"

    async with websockets.connect(uri) as websocket:
        # Subscribe to kline data
        subscribe_msg = {
            "method": "SUBSCRIPTION",
            "params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
        }
        await websocket.send(json.dumps(subscribe_msg))

        logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")

        while True:
            try:
                response = await websocket.recv()
                data = json.loads(response)

                if 'data' in data:
                    kline = data['data']
                    candle = {
                        'timestamp': kline['t'],
                        'open': float(kline['o']),
                        'high': float(kline['h']),
                        'low': float(kline['l']),
                        'close': float(kline['c']),
                        'volume': float(kline['v'])
                    }
                    yield candle

            except Exception as e:
                logger.error(f"Websocket error: {e}")
                # Wait briefly, then stop the generator so the caller can reconnect
                await asyncio.sleep(5)
                break
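
# Illustrative sketch (assumption, not wired into the bot): consuming the async
# candle generator above and feeding each new candle into an environment.
async def _example_stream_candles(env, max_candles=10):
    count = 0
    async for candle in get_live_prices("ETH/USDT", "1m"):
        env.add_data(candle)  # same update path live_trading() uses
        logger.info(f"Live close: {candle['close']:.2f}")
        count += 1
        if count >= max_candles:
            break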

async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
    """Train the agent using historical and live data with GPU acceleration"""
    logger.info(f"Starting training on device: {agent.device}")

    stats = {
        'episode_rewards': [],
        'episode_lengths': [],
        'balances': [],
        'win_rates': [],
        'episode_pnls': [],
        'cumulative_pnl': [],
        'drawdowns': [],
        'prediction_accuracy': [],
        'trade_analysis': []
    }

    best_reward = -float('inf')
    best_pnl = -float('inf')

    try:
        # Initialize price predictor
        env.initialize_price_predictor(agent.device)

        for episode in range(num_episodes):
            try:
                # Reset environment
                state = env.reset()
                episode_reward = 0
                env.episode_pnl = 0.0  # Reset episode PnL

                # Identify optimal trade points for this episode
                env.identify_optimal_trades()

                # Train price predictor
                prediction_loss = env.train_price_predictor()

                # Update price predictions
                env.update_price_predictions()

                for step in range(max_steps_per_episode):
                    # Select action
                    action = agent.select_action(state)

                    # Take action
                    next_state, reward, done = env.step(action)

                    # Store experience
                    agent.memory.push(state, action, reward, next_state, done)

                    state = next_state
                    episode_reward += reward

                    # Learn from experience with mixed precision
                    try:
                        loss = agent.learn()
                        if loss is not None:
                            agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
                    except Exception as e:
                        logger.error(f"Learning error in episode {episode}, step {step}: {e}")

                    # Update price predictions periodically
                    if step % 10 == 0:
                        env.update_price_predictions()

                    # Add chart to TensorBoard periodically
                    if step % 50 == 0 or (step == max_steps_per_episode - 1) or done:
                        global_step = episode * max_steps_per_episode + step
                        agent.add_chart_to_tensorboard(env, global_step)

                    if done:
                        break

                # Update target network
                if episode % TARGET_UPDATE == 0:
                    agent.target_net.load_state_dict(agent.policy_net.state_dict())

                # Calculate win rate
                if len(env.trades) > 0:
                    wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
                    win_rate = wins / len(env.trades) * 100
                else:
                    win_rate = 0

                # Analyze trades
                trade_analysis = env.analyze_trades()
                stats['trade_analysis'].append(trade_analysis)

                # Calculate prediction accuracy
                prediction_accuracy = 0.0
                if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
                    if len(env.data) > 5:
                        actual_prices = [candle['close'] for candle in env.data[-5:]]
                        predicted = env.predicted_prices[:min(5, len(actual_prices))]
                        errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
                        prediction_accuracy = 100 * (1 - sum(errors) / len(errors))

                # Log statistics
                stats['episode_rewards'].append(episode_reward)
                stats['episode_lengths'].append(step + 1)
                stats['balances'].append(env.balance)
                stats['win_rates'].append(win_rate)
                stats['episode_pnls'].append(env.episode_pnl)
                stats['cumulative_pnl'].append(env.total_pnl)
                stats['drawdowns'].append(env.max_drawdown * 100)
                stats['prediction_accuracy'].append(prediction_accuracy)

                # Log detailed trade analysis
                if trade_analysis:
                    logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
                                f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
                                f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")

                # Log to TensorBoard
                agent.writer.add_scalar('Reward/train', episode_reward, episode)
                agent.writer.add_scalar('Balance/train', env.balance, episode)
                agent.writer.add_scalar('WinRate/train', win_rate, episode)
                agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
                agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
                agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
                if prediction_loss is not None:
                    agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)
                agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)

                # Add final chart for this episode
                agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)

                logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
                            f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
                            f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
                            f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")

                # Save best model by reward
                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save("models/trading_agent_best_reward.pt")

                # Save best model by PnL
                if env.episode_pnl > best_pnl:
                    best_pnl = env.episode_pnl
                    agent.save("models/trading_agent_best_pnl.pt")

                # Save checkpoint
                if episode % 10 == 0:
                    agent.save(f"models/trading_agent_episode_{episode}.pt")

            except Exception as e:
                logger.error(f"Error in episode {episode}: {e}")
                continue

        # Save final model
        agent.save("models/trading_agent_final.pt")

        # Plot training results
        plot_training_results(stats)

    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

    return stats

def plot_training_results(stats):
    """Plot detailed training results"""
    plt.figure(figsize=(20, 15))

    # Plot rewards
    plt.subplot(3, 2, 1)
    plt.plot(stats['episode_rewards'])
    plt.title('Episode Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')

    # Plot balance
    plt.subplot(3, 2, 2)
    plt.plot(stats['balances'])
    plt.title('Account Balance')
    plt.xlabel('Episode')
    plt.ylabel('Balance ($)')

    # Plot win rate
    plt.subplot(3, 2, 3)
    plt.plot(stats['win_rates'])
    plt.title('Win Rate')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate (%)')

    # Plot episode PnL
    plt.subplot(3, 2, 4)
    plt.plot(stats['episode_pnls'])
    plt.title('Episode PnL')
    plt.xlabel('Episode')
    plt.ylabel('PnL ($)')

    # Plot cumulative PnL
    plt.subplot(3, 2, 5)
    plt.plot(stats['cumulative_pnl'])
    plt.title('Cumulative PnL')
    plt.xlabel('Episode')
    plt.ylabel('Cumulative PnL ($)')

    # Plot drawdown
    plt.subplot(3, 2, 6)
    plt.plot(stats['drawdowns'])
    plt.title('Maximum Drawdown')
    plt.xlabel('Episode')
    plt.ylabel('Drawdown (%)')

    plt.tight_layout()
    plt.savefig('training_results.png')

    # Save statistics to CSV
    df = pd.DataFrame(stats)
    df.to_csv('training_stats.csv', index=False)

    logger.info("Training statistics saved to training_stats.csv and training_results.png")

def evaluate_agent(agent, env, num_episodes=10):
    """Evaluate the agent on test data"""
    total_reward = 0
    total_profit = 0
    total_trades = 0
    winning_trades = 0

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        initial_balance = env.balance

        done = False
        while not done:
            # Select action (no exploration)
            action = agent.select_action(state, training=False)
            next_state, reward, done = env.step(action)

            state = next_state
            episode_reward += reward

        total_reward += episode_reward
        total_profit += env.balance - initial_balance

        # Count trades and wins
        for trade in env.trades:
            if 'pnl_percent' in trade:
                total_trades += 1
                if trade['pnl_percent'] > 0:
                    winning_trades += 1

    # Calculate averages
    avg_reward = total_reward / num_episodes
    avg_profit = total_profit / num_episodes
    win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0

    logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
                f"Win Rate={win_rate:.1f}%")

    return avg_reward, avg_profit, win_rate

async def test_training():
    """Test the training process with a small number of episodes"""
    logger.info("Starting training tests...")

    # Initialize exchange
    exchange = ccxt.mexc({
        'apiKey': MEXC_API_KEY,
        'secret': MEXC_SECRET_KEY,
        'enableRateLimit': True,
    })

    try:
        # Create environment with small initial balance for testing
        env = TradingEnvironment(
            exchange=exchange,
            symbol="ETH/USDT",
            timeframe="1m",
            leverage=MAX_LEVERAGE,
            initial_balance=100,  # Small balance for testing
            demo=True  # Always use demo mode for testing
        )

        # Fetch initial data
        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)

        # Create agent
        agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)

        # Run a few test episodes
        test_episodes = 3
        logger.info(f"Running {test_episodes} test episodes...")

        for episode in range(test_episodes):
            state = env.reset()
            episode_reward = 0
            done = False
            step = 0

            while not done and step < 100:  # Limit steps for testing
                # Select action
                action = agent.select_action(state)

                # Take action
                next_state, reward, done = env.step(action)

                # Store experience
                agent.memory.push(state, action, reward, next_state, done)

                # Learn
                loss = agent.learn()

                state = next_state
                episode_reward += reward
                step += 1

                # Print progress
                if step % 10 == 0:
                    logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")

            logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")

        # Test model saving
        try:
            agent.save("models/test_model.pt")
            logger.info("Successfully saved model")
        except Exception as e:
            logger.error(f"Error saving model: {e}")

        logger.info("Training tests completed successfully")
        return True

    except Exception as e:
        logger.error(f"Training test failed: {e}")
        return False

    finally:
        # Close the exchange connection if the client supports it
        # (plain CCXT clients are synchronous; ccxt.pro clients return a coroutine)
        if hasattr(exchange, 'close'):
            close_result = exchange.close()
            if asyncio.iscoroutine(close_result):
                await close_result

async def initialize_exchange():
    """Initialize the exchange connection"""
    try:
        # Try to initialize with async support first
        try:
            exchange = ccxt.pro.mexc({
                'apiKey': MEXC_API_KEY,
                'secret': MEXC_SECRET_KEY,
                'enableRateLimit': True
            })
            logger.info(f"Exchange initialized with async support: {exchange.id}")
        except (AttributeError, ImportError):
            # Fall back to standard CCXT
            exchange = ccxt.mexc({
                'apiKey': MEXC_API_KEY,
                'secret': MEXC_SECRET_KEY,
                'enableRateLimit': True
            })
            logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")

        return exchange
    except Exception as e:
        logger.error(f"Failed to initialize exchange: {e}")
        raise

async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
    """Fetch historical OHLCV data from the exchange"""
    try:
        logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")

        # Use the refactored fetch method
        data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)

        if not data:
            logger.warning("No historical data received")

        return data
    except Exception as e:
        logger.error(f"Failed to fetch historical data: {e}")
        return []

async def live_trading(agent, env, exchange, demo=True):
    """Run live trading with the trained agent"""
    logger.info(f"Starting live trading (demo mode: {demo})")

    try:
        # Symbol and timeframe for real-time data
        symbol = "ETH/USDT"
        timeframe = "1m"

        # Initialize with historical data
        success = await env.fetch_initial_data(exchange, symbol, timeframe, 100)
        if not success:
            logger.error("Failed to initialize with historical data")
            return

        # Main trading loop
        step_counter = 0

        # For online learning
        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []

        while True:
            # Wait for the next candle (1 minute)
            await asyncio.sleep(5)  # Check every 5 seconds

            # Fetch latest candle
            latest_candle = await get_latest_candle(exchange, symbol)

            if not latest_candle:
                logger.warning("No latest candle received, skipping update")
                continue

            # Update environment with new data
            env.add_data(latest_candle)

            # Get current state
            state = env.get_state()

            # Select action (no exploration in live trading)
            action = agent.select_action(state, training=False)

            # Take action
            next_state, reward, done = env.step(action)

            # Store experience for online learning
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

            # Online learning - update the model with new experiences
            if len(states) >= 10:  # Batch size for online learning
                # Store experiences in replay memory
                for i in range(len(states)):
                    agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])

                # Learn from experiences if we have enough samples
                if len(agent.memory) > 32:
                    loss = agent.learn()
                    if loss is not None:
                        agent.writer.add_scalar('Live/Loss', loss, step_counter)

                # Clear the temporary storage
                states = []
                actions = []
                rewards = []
                next_states = []
                dones = []

                # Save the updated model periodically
                if step_counter % 100 == 0:
                    agent.save("models/trading_agent_live_updated.pt")
                    logger.info("Updated model saved during live trading")

            # Log trading activity
            action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
            logger.info(f"Price: ${latest_candle['close']:.2f} | Action: {action_names[action]}")

            # Log performance metrics
            if env.trades:
                wins = sum(1 for t in env.trades if t.get('pnl_percent', 0) > 0)
                win_rate = wins / len(env.trades) * 100
                total_pnl = sum(t.get('pnl_dollar', 0) for t in env.trades)

                logger.info(f"Balance: ${env.balance:.2f} | Trades: {len(env.trades)} | "
                            f"Win Rate: {win_rate:.1f}% | Total PnL: ${total_pnl:.2f}")

                # Analyze recent trades
                trade_analysis = env.analyze_trades()
                if trade_analysis:
                    logger.info(f"Recent Performance: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
                                f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends")

            # Add chart to TensorBoard periodically
            step_counter += 1
            if step_counter % 10 == 0:  # Update chart every 10 steps
                agent.add_chart_to_tensorboard(env, step_counter)

                # Also log current PnL and balance
                agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                agent.writer.add_scalar('Live/TotalPnL', env.total_pnl, step_counter)
                agent.writer.add_scalar('Live/WinRate',
                                        (env.win_count / (env.win_count + env.loss_count) * 100)
                                        if (env.win_count + env.loss_count) > 0 else 0,
                                        step_counter)

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")
    except Exception as e:
        logger.error(f"Error in live trading: {e}")
        raise

async def get_latest_candle(exchange, symbol):
    """Get the latest candle data"""
    try:
        # Use the refactored fetch method with limit=1
        data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)

        if data and len(data) > 0:
            return data[0]
        else:
            logger.warning("No candle data received")
            return None
    except Exception as e:
        logger.error(f"Failed to fetch latest candle: {e}")
        return None

async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
    """Fetch OHLCV data with proper handling for both async and standard CCXT"""
    try:
        # Check if exchange has fetchOHLCV method
        if not hasattr(exchange, 'fetchOHLCV'):
            logger.error("Exchange does not support OHLCV data fetching")
            return []

        # Handle async (ccxt.pro) and synchronous CCXT clients
        if asyncio.iscoroutinefunction(exchange.fetch_ohlcv):
            # Async client: await the call directly
            ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        else:
            # Synchronous client: run in a thread pool so the event loop is not blocked
            loop = asyncio.get_event_loop()
            ohlcv = await loop.run_in_executor(
                None,
                lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
            )

        # Convert to list of dictionaries
        data = []
        for candle in ohlcv:
            timestamp, open_price, high, low, close, volume = candle
            data.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
        return data

    except Exception as e:
        logger.error(f"Error fetching OHLCV data: {e}")
        return []
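
# Illustrative helper (assumption, not used elsewhere): turning the list of candle
# dicts returned by fetch_ohlcv_data() into a pandas DataFrame for offline analysis.
def _example_candles_to_dataframe(candles):
    df = pd.DataFrame(candles)
    # CCXT OHLCV timestamps are milliseconds since the epoch
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
    return df.set_index('timestamp')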

async def main():
    """Main function to run the trading bot"""
    parser = argparse.ArgumentParser(description='Crypto Trading Bot')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
                        help='Mode to run the bot in')
    parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
    parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
    args = parser.parse_args()

    # Get device (GPU or CPU)
    device = get_device()

    exchange = None

    try:
        # Initialize exchange
        exchange = await initialize_exchange()

        # Create environment
        env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)

        # Fetch initial data
        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)

        # Create agent
        agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

        if args.mode == 'train':
            # Train the agent
            logger.info(f"Starting training for {args.episodes} episodes...")
            stats = await train_agent(agent, env, num_episodes=args.episodes)

        elif args.mode == 'evaluate':
            # Load trained model
            agent.load("models/trading_agent_best_pnl.pt")

            # Evaluate the agent
            logger.info("Evaluating agent...")
            avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)

        elif args.mode == 'live':
            # Load trained model
            agent.load("models/trading_agent_best_pnl.pt")

            # Run live trading
            logger.info("Starting live trading...")
            await live_trading(agent, env, exchange, demo=args.demo)

    finally:
        # Clean up the exchange connection - close safely if possible
        if exchange:
            try:
                # Some CCXT exchanges have a close method, others don't
                if hasattr(exchange, 'close'):
                    await exchange.close()
                elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
                    await exchange.client.close()
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.warning(f"Could not properly close exchange connection: {e}")
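
# Example invocations (assuming this module is saved as, e.g., main.py):
#
#   python main.py --mode train --episodes 500 --demo   # train against demo data
#   python main.py --mode evaluate                      # evaluate the best-PnL checkpoint
#   python main.py --mode live --demo                   # paper-trade with the trained agent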

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")