4844 lines
203 KiB
Python
4844 lines
203 KiB
Python
import os
|
|
import time
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
import random
|
|
import logging
|
|
import asyncio
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
from collections import deque, namedtuple
|
|
from dotenv import load_dotenv
|
|
import ccxt
|
|
import websockets
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import torch.cuda.amp as amp # Add this import at the top
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
import copy
|
|
import argparse
|
|
import traceback
|
|
import io
|
|
import matplotlib.dates as mdates
|
|
from matplotlib.figure import Figure
|
|
from PIL import Image
|
|
import matplotlib.pyplot as mpf
|
|
import matplotlib.gridspec as gridspec
|
|
import datetime
|
|
from realtime import BinanceWebSocket, BinanceHistoricalData
|
|
from datetime import datetime as dt
|
|
# Add Dash-related imports
|
|
import dash
|
|
from dash import html, dcc, callback_context
|
|
from dash.dependencies import Input, Output, State
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
from threading import Thread
|
|
import socket
|
|
|
|
# Configure logging
|
|
# Root logging: INFO and above is written to trading_bot.log and the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler("trading_bot.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("trading_bot")

# Raise the 'websocket' logger to INFO so records below that level logged
# through it are discarded.
websocket_logger = logging.getLogger('websocket')
websocket_logger.setLevel(logging.INFO)
|
|
|
|
# Logging filter that suppresses DEBUG records from WebSocket-related loggers
|
|
class WebSocketFilter(logging.Filter):
    """Drop DEBUG records emitted by WebSocket-related loggers.

    A record is rejected only when its level is exactly ``logging.DEBUG`` and
    its logger name contains one of ``NOISY_NAME_PARTS``; every other record
    passes through unchanged.
    """

    # Substrings of logger names whose DEBUG output should be suppressed.
    NOISY_NAME_PARTS = ('websocket', 'protocol', 'realtime')

    def filter(self, record):
        """Return False for noisy DEBUG records, True for everything else."""
        if record.levelno != logging.DEBUG:
            return True
        return not any(part in record.name for part in self.NOISY_NAME_PARTS)


_websocket_filter = WebSocketFilter()

# Keep the filter on the bot's own logger, as before.
logging.getLogger("trading_bot").addFilter(_websocket_filter)

# A filter attached to a logger is only consulted for records logged through
# that logger; records propagated from other loggers (e.g. 'websockets.*')
# bypass it. Attaching the filter to the root handlers as well makes it apply
# to every record those handlers would emit.
for _root_handler in logging.getLogger().handlers:
    _root_handler.addFilter(_websocket_filter)
|
|
|
|
# Load environment variables
|
|
# Load variables from a .env file (if any), then read the MEXC credentials.
# Either value is None when the corresponding variable is not set.
load_dotenv()
MEXC_API_KEY = os.environ.get('MEXC_API_KEY')
MEXC_SECRET_KEY = os.environ.get('MEXC_SECRET_KEY')
|
|
|
|
# Constants
|
|
# --- Account / risk settings ---
INITIAL_BALANCE = 100  # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5  # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5  # Take profit at 1.5%

# --- Replay buffer / optimisation settings ---
MEMORY_SIZE = 100_000
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10  # Update target network every 10 episodes

# --- Exploration schedule ---
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10_000

# --- Model input ---
STATE_SIZE = 64  # Size of our state representation

# One replay-buffer entry: a single (s, a, r, s', done) transition.
Experience = namedtuple('Experience', 'state action reward next_state done')
|
|
|
|
# Helper for locating local minima and maxima in a price series
|
|
def find_local_extrema(prices, window=5):
    """Locate local bottoms and tops in a price sequence.

    An index ``i`` is a bottom when ``prices[i]`` is less than or equal to
    each of the ``window`` values on either side of it, and a top when it is
    greater than or equal to all of them. Ties count, so a flat stretch can
    be reported as both.

    Args:
        prices: Indexable sequence of comparable values.
        window: Number of neighbours inspected on each side.

    Returns:
        ``(bottoms, tops)`` - two lists of indices in ascending order. Both
        are empty when fewer than ``2 * window + 1`` prices are given.
    """
    troughs, peaks = [], []
    count = len(prices)

    if count < 2 * window + 1:
        return troughs, peaks

    offsets = range(1, window + 1)
    # Only indices with a full window on both sides are candidates.
    for center in range(window, count - window):
        pivot = prices[center]
        neighbours = [prices[center - k] for k in offsets]
        neighbours += [prices[center + k] for k in offsets]

        if all(pivot <= other for other in neighbours):
            troughs.append(center)
        if all(pivot >= other for other in neighbours):
            peaks.append(center)

    return troughs, peaks
|
|
|
|
class ReplayMemory:
    """Bounded buffer of ``Experience`` transitions for replay sampling."""

    def __init__(self, capacity):
        # A deque with maxlen silently evicts the oldest entry once full.
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = Experience(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
|
|
|
|
class DQN(nn.Module):
    """Dueling Q-network with parallel LSTM and Transformer-encoder branches.

    The input is projected by a linear layer, then fed to an LSTM and to a
    Transformer encoder; the two branch outputs are summed and passed through
    further linear layers into separate value and advantage heads.

    Args:
        state_size: Number of features expected per state vector.
        action_size: Number of discrete actions (width of the output).
        hidden_size: Width of the hidden layers.
        lstm_layers: Number of stacked LSTM layers.
        attention_heads: Number of heads for the attention/Transformer layers.
    """

    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super().__init__()

        self.state_size = state_size
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers

        # NOTE: submodules are created in the original order so parameter
        # initialisation and state_dict keys stay unchanged.

        # Initial feature extraction
        self.fc1 = nn.Linear(state_size, hidden_size)
        # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
        self.ln1 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(0.2)

        # LSTM layer for sequential data
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)

        # Attention mechanism. Not called in forward(); kept so the set of
        # registered parameters is unchanged.
        self.attention = nn.MultiheadAttention(hidden_size, attention_heads)

        # Output layers with increased capacity
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(hidden_size, hidden_size // 2)

        # Dueling DQN architecture
        self.value_stream = nn.Linear(hidden_size // 2, 1)
        self.advantage_stream = nn.Linear(hidden_size // 2, action_size)

        # Transformer encoder for more complex pattern recognition
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_size)``.

        A 1-D input is treated as a single state. If the feature dimension
        differs from ``state_size`` it is truncated or zero-padded on the
        right.
        NOTE(review): the size check reads ``x.size(1)``, so it presumably
        expects 2-D ``(batch, features)`` input - confirm before passing
        3-D sequences.
        """
        # Ensure input has a batch dimension
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch_size = x.size(0)

        # Truncate or pad so the feature dimension matches state_size
        feature_count = x.size(1)
        if feature_count > self.state_size:
            x = x[:, :self.state_size]
        elif feature_count < self.state_size:
            # Match the input's floating dtype so half/bfloat16/double inputs
            # are not silently promoted (or rejected) by torch.cat; integer
            # inputs keep the previous behaviour of default-dtype padding.
            pad_dtype = x.dtype if x.is_floating_point() else None
            padding = torch.zeros(
                batch_size,
                self.state_size - feature_count,
                dtype=pad_dtype,
                device=x.device,
            )
            x = torch.cat([x, padding], dim=1)

        # Initial feature extraction
        x = self.fc1(x)
        x = F.relu(self.ln1(x))  # LayerNorm works with any batch size
        x = self.dropout1(x)

        # Both branches expect (batch, seq, hidden); a 2-D tensor becomes a
        # length-1 sequence.
        sequence = x.unsqueeze(1) if x.dim() == 2 else x

        # LSTM branch: keep the output of the last time step
        lstm_out, _ = self.lstm(sequence)
        lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]

        # Transformer branch: the encoder is sequence-first, so transpose in
        # and out, then average over the sequence dimension.
        transformer_out = self.transformer_encoder(sequence.transpose(0, 1))
        transformer_out = transformer_out.transpose(0, 1).mean(dim=1)

        # Combine LSTM and transformer outputs
        x = lstm_out + transformer_out

        # Final layers
        x = self.fc2(x)
        x = F.relu(self.ln2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))

        # Dueling architecture: Q = V + (A - mean(A))
        value = self.value_stream(x)
        advantages = self.advantage_stream(x)
        qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))

        return qvals
|
|
|
|
class PricePredictionModel(nn.Module):
    """LSTM that predicts the next few prices from a window of recent prices.

    Prices are scaled with a ``MinMaxScaler`` that is fitted once, on the
    first data passed to ``preprocess``, and reused afterwards. Each sample
    is a window of ``SEQ_LEN`` scaled prices; the output width is the
    ``output_size`` given at construction.
    """

    # Number of consecutive scaled prices fed to the LSTM per sample.
    SEQ_LEN = 30

    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
        super().__init__()
        # NOTE(review): `input_size` is accepted but never read; the window
        # length actually used is SEQ_LEN.
        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False

    def forward(self, x):
        """Map ``[batch_size, seq_len, 1]`` input to ``[batch_size, output_size]``."""
        lstm_out, _ = self.lstm(x)
        # Use the last time step output
        return self.fc(lstm_out[:, -1, :])

    def preprocess(self, data):
        """Scale ``data`` to a column vector, fitting the scaler on first use."""
        data_reshaped = np.array(data).reshape(-1, 1)

        # The scaler is fitted only once; later calls reuse that fit.
        if not self.is_fitted:
            self.scaler.fit(data_reshaped)
            self.is_fitted = True

        return self.scaler.transform(data_reshaped)

    def postprocess(self, scaled_predictions):
        """Inverse-transform scaled predictions into a flat array of prices."""
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        """Predict upcoming prices from the last ``SEQ_LEN`` entries.

        Returns ``np.zeros(num_candles)`` when the history is shorter than
        ``SEQ_LEN``. Otherwise the result has one entry per model output;
        ``num_candles`` does not change that length.
        """
        if len(price_history) < self.SEQ_LEN:  # Need enough history
            return np.zeros(num_candles)

        scaled_data = self.preprocess(price_history)

        # Single-sample batch built from the most recent window
        sequence = scaled_data[-self.SEQ_LEN:].reshape(1, self.SEQ_LEN, 1)
        device = next(self.parameters()).device
        sequence_tensor = torch.FloatTensor(sequence).to(device)

        # NOTE(review): the module's train/eval mode is not changed here, so
        # LSTM dropout is active if the model is in training mode - confirm
        # callers switch to eval() when deterministic output is needed.
        with torch.no_grad():
            scaled_predictions = self(sequence_tensor).cpu().numpy()[0]

        return self.postprocess(scaled_predictions)

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        """Run ``epochs`` full-batch MSE updates on windows of ``price_history``.

        Every window of ``SEQ_LEN`` prices is paired with the prices that
        immediately follow it (as many as the model has outputs).

        Returns:
            Mean loss over the epochs, or 0.0 when there is not enough
            history for one sample or ``epochs`` is not positive.
        """
        window = self.SEQ_LEN
        horizon = self.fc.out_features
        span = window + horizon

        # Need at least one full (window, target) pair and one epoch;
        # epochs <= 0 previously ended in a division by zero.
        if epochs <= 0 or len(price_history) < span:
            return 0.0

        scaled_data = self.preprocess(price_history)

        # "+ 1" so the final complete window is included (it was skipped
        # before, leaving no samples at all for exactly `span` prices).
        sample_count = len(scaled_data) - span + 1
        sequences = [scaled_data[i:i + window] for i in range(sample_count)]
        targets = [scaled_data[i + window:i + span].flatten() for i in range(sample_count)]

        # Convert to tensors on the model's device
        device = next(self.parameters()).device
        sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, window, 1)).to(device)
        targets_tensor = torch.FloatTensor(np.array(targets)).to(device)

        # Training loop (one full-batch step per epoch)
        total_loss = 0
        for _ in range(epochs):
            predictions = self(sequences_tensor)
            loss = F.mse_loss(predictions, targets_tensor)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / epochs
|
|
|
|
import os
|
|
import time
|
|
import logging
|
|
import sys
|
|
import argparse
|
|
import json
|
|
|
|
# Add the NN directory to the Python path
|
|
sys.path.append(os.path.abspath("NN"))
|
|
|
|
from NN.main import load_model
|
|
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
|
|
from NN.realtime_data_interface import RealtimeDataInterface
|
|
|
|
# Initialize logging
|
|
# Initialize logging: INFO and above to trading_bot.log and the console.
# NOTE(review): logging.basicConfig() does nothing when the root logger
# already has handlers, and an earlier basicConfig() call exists in this
# file - confirm which format is meant to win.
logging.basicConfig(
    handlers=[
        logging.FileHandler("trading_bot.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
def main():
    """Main function for the trading bot.

    Parses the command line, optionally merges a JSON configuration file,
    builds the charts, data interface, model and orchestrator, starts them,
    and then blocks until interrupted with Ctrl+C.

    Configuration precedence (highest first): options given explicitly on
    the command line, values from ``--config-file``, built-in defaults.
    Config-file keys are expected to use the argparse destination names
    (e.g. ``window_size``).

    Exits the process with status 1 if start-up or the main loop raises.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
    parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
                        help='Trading symbols to monitor')
    parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
                        help='Timeframes to monitor')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for model input')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size of the model (3 for BUY/HOLD/SELL)')
    parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
                        help='Type of neural network model')
    parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
                        help='Trading mode')
    parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
                        help='Exchange to use for trading')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox exchange environment')
    parser.add_argument('--position-size', type=float, default=0.1,
                        help='Position size as a fraction of total balance (0.0-1.0)')
    parser.add_argument('--max-trades-per-day', type=int, default=5,
                        help='Maximum number of trades per day')
    parser.add_argument('--trade-cooldown', type=int, default=60,
                        help='Trade cooldown period in minutes')
    parser.add_argument('--config-file', type=str, default=None,
                        help='Path to configuration file')

    args = parser.parse_args()

    # Load configuration from file if provided
    if args.config_file and os.path.exists(args.config_file):
        with open(args.config_file, 'r', encoding='utf-8') as f:
            file_config = json.load(f)
        if not isinstance(file_config, dict):
            raise ValueError(f"Configuration file {args.config_file} must contain a JSON object")

        # Argparse defaults are never None, so copying every parsed value
        # over the file contents (the old behaviour) always discarded the
        # file. Instead the file values become the parser defaults and the
        # command line is parsed again, so only options actually given on
        # the command line override them. Null entries are ignored so the
        # built-in default still applies.
        parser.set_defaults(**{
            key: value for key, value in file_config.items() if value is not None
        })
        args = parser.parse_args()

    config = vars(args)

    # Initialize real-time charts and data interfaces
    try:
        from realtime import RealTimeChart

        # Create a real-time chart for each symbol
        charts = {}
        for symbol in config['symbols']:
            charts[symbol] = RealTimeChart(symbol=symbol)

        # The first symbol's chart is the one shared with the data
        # interface, the orchestrator and the web server.
        main_chart = charts[config['symbols'][0]]

        # Create a data interface for retrieving market data
        data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart)

        # Load trained model (NN_MODEL_TYPE in the environment overrides the config)
        model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
        model = load_model(
            model_type=model_type,
            input_shape=(config['window_size'], len(config['symbols']), 5),  # 5 features (OHLCV)
            output_size=config['output_size']
        )

        # Configure trading agent
        exchange_config = {
            "exchange": config['exchange'],
            "api_key": config['api_key'],
            "api_secret": config['api_secret'],
            "test_mode": config['test_mode'],
            "trade_symbols": config['symbols'],
            "position_size": config['position_size'],
            "max_trades_per_day": config['max_trades_per_day'],
            "trade_cooldown_minutes": config['trade_cooldown']
        }

        # Initialize neural network orchestrator
        orchestrator = NeuralNetworkOrchestrator(
            model=model,
            data_interface=data_interface,
            chart=main_chart,
            symbols=config['symbols'],
            timeframes=config['timeframes'],
            window_size=config['window_size'],
            num_features=5,  # OHLCV
            output_size=config['output_size'],
            exchange_config=exchange_config
        )

        # Start data collection
        logger.info("Starting data collection threads...")
        for symbol in config['symbols']:
            charts[symbol].start()

        # Start neural network inference (opt-in via environment variable)
        if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
            logger.info("Starting neural network inference...")
            orchestrator.start_inference()
        else:
            logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")

        # Start web servers for chart display
        logger.info("Starting web servers for chart display...")
        main_chart.start_server()

        logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")

        # Keep the main thread alive
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt received. Shutting down...")
            # Stop all threads
            for symbol in config['symbols']:
                charts[symbol].stop()
            orchestrator.stop_inference()
            logger.info("Trading bot stopped.")

    except Exception as e:
        # Top-level boundary: log with traceback and exit non-zero.
        logger.error(f"Error in main function: {str(e)}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
|
|
def get_state(self):
    """Create state representation for the agent with enhanced features.

    The vector is the concatenation of: 10 normalised recent prices, 5
    normalised volumes, 3 RSI values, 9 MACD values, 3 Bollinger positions,
    6 stochastic values, 3 predicted prices, 5 extrema signals, 5 position
    fields, 3 momentum values, 3 volatility values, 5 market-regime flags
    and 4 trade-history values.

    Returns:
        1-D array of length ``STATE_SIZE``. All zeros when fewer than 30 data
        rows exist, ``current_step`` is past the end of the data, or no price
        has been recorded. NaN/inf entries are replaced by 0 and the vector
        is zero-padded or truncated (with a warning) if its length differs
        from ``STATE_SIZE``.
    """
    # Ensure we have enough data
    if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
        return np.zeros(STATE_SIZE)

    def _safe_ratio(numerator, denominator):
        # Scalar division that yields 0.0 instead of raising when the
        # denominator is zero.
        return numerator / denominator if denominator != 0 else 0.0

    # Create a normalized state vector with recent price action and indicators
    state_components = []

    # Safely get the latest price
    try:
        latest_price = self.features['price'][-1]
    except IndexError:
        return np.zeros(STATE_SIZE)

    # Price features (normalize recent prices by the latest price)
    try:
        price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
        state_components.append(price_features)
    except (IndexError, ZeroDivisionError):
        state_components.append(np.zeros(10))

    # Volume features (normalize by max volume of the last 20 entries)
    try:
        max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
        vol_features = np.array(self.features['volume'][-5:]) / max_vol
        state_components.append(vol_features)
    except (IndexError, ZeroDivisionError):
        state_components.append(np.zeros(5))

    # Technical indicators
    rsi = np.array(self.features['rsi'][-3:]) / 100.0  # Scale to 0-1
    state_components.append(rsi)

    # MACD: all three series share one scale taken from the MACD line
    macd_vals = np.array(self.features['macd'][-3:])
    macd_signal = np.array(self.features['macd_signal'][-3:])
    macd_hist = np.array(self.features['macd_hist'][-3:])
    macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
    state_components.extend([
        macd_vals / macd_scale,
        macd_signal / macd_scale,
        macd_hist / macd_scale,
    ])

    # Bollinger position (where is price relative to bands)
    bb_upper = np.array(self.features['bollinger_upper'][-3:])
    bb_lower = np.array(self.features['bollinger_lower'][-3:])
    price = np.array(self.features['price'][-3:])

    # Position of price within the bands; 0.5 when the bands coincide
    bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
    state_components.append(np.array(bb_pos))

    # Stochastic oscillator
    state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0)
    state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0)

    # Predicted prices (if available), normalized relative to current price
    if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
        pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
        state_components.append(pred_norm)
    else:
        state_components.append(np.zeros(3))

    # Extrema signals (if available): the five most recent, zero-padded
    if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0:
        # np.asarray so a plain-list slice can be flattened below.
        recent_signals = np.asarray(self.optimal_signals[-5:])
        if len(recent_signals) < 5:
            recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
        state_components.append(recent_signals)
    else:
        state_components.append(np.zeros(5))

    # Position info: [direction, unrealized PnL %, stop loss %, take profit %,
    # position size relative to balance]. Ratios fall back to 0 rather than
    # raising when entry_price is 0 or the balance is not positive.
    position_info = np.zeros(5)
    if self.position in ('long', 'short'):
        entry = self.entry_price
        if self.position == 'long':
            position_info[0] = 1.0
            position_info[1] = _safe_ratio(latest_price - entry, entry)
            position_info[2] = _safe_ratio(self.stop_loss - entry, entry)
            position_info[3] = _safe_ratio(self.take_profit - entry, entry)
        else:
            position_info[0] = -1.0
            position_info[1] = _safe_ratio(entry - latest_price, entry)
            position_info[2] = _safe_ratio(entry - self.stop_loss, entry)
            position_info[3] = _safe_ratio(entry - self.take_profit, entry)
        position_info[4] = self.position_size / self.balance if self.balance > 0 else 0

    state_components.append(position_info)

    # NEW FEATURES START HERE

    # 1. Price momentum features (rate of change over different periods)
    if len(self.features['price']) >= 20:
        roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0
        roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0
        roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0
        state_components.append(np.array([roc_5, roc_10, roc_20]))
    else:
        state_components.append(np.zeros(3))

    # 2. Volatility features
    if len(self.features['price']) >= 20:
        # Standard deviation of simple returns over the recent window
        returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1]
        volatility = np.std(returns)
        # Mean high-low range of the last five rows, normalized by close
        high_low_range = np.mean([
            (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close']
            for i in range(max(0, len(self.data)-5), len(self.data))
        ]) if len(self.data) > 0 else 0
        # ATR normalized by price
        atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0

        state_components.append(np.array([volatility, high_low_range, atr_norm]))
    else:
        state_components.append(np.zeros(3))

    # 3. Market regime features
    if len(self.features['price']) >= 50:
        # Trend strength: relative gap between the two EMAs
        ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
        ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
        trend_strength = _safe_ratio(abs(ema9 - ema21), ema21)

        # Detect if in range or trending
        is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
        is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0

        # Detect if near support/resistance
        near_support = 1.0 if self.is_near_support() else 0.0
        near_resistance = 1.0 if self.is_near_resistance() else 0.0

        market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance])
        state_components.append(market_regime)
    else:
        state_components.append(np.zeros(5))

    # 4. Trade history features (last 10 trades at most)
    if len(self.trades) > 0:
        recent_pnls = [t.get('pnl_dollar', 0) for t in self.trades[-10:]]
        winning = [pnl for pnl in recent_pnls if pnl > 0]
        losing = [pnl for pnl in recent_pnls if pnl <= 0]

        win_ratio = len(winning) / len(recent_pnls)
        avg_profit = np.mean(winning) if winning else 0
        avg_loss = np.mean(losing) if losing else 0

        # Normalize by balance
        has_balance = self.balance > 0
        avg_profit_norm = avg_profit / self.balance if has_balance else 0
        avg_loss_norm = avg_loss / self.balance if has_balance else 0
        last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if has_balance else 0

        state_components.append(np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]))
    else:
        state_components.append(np.zeros(4))

    # Combine all features
    state = np.concatenate([np.ravel(comp) for comp in state_components])

    # Replace any NaN or infinite values
    state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)

    # Ensure the state has the correct size
    if len(state) != STATE_SIZE:
        logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
        if len(state) < STATE_SIZE:
            state = np.pad(state, (0, STATE_SIZE - len(state)))
        else:
            state = state[:STATE_SIZE]

    return state
|
|
|
|
def get_expanded_state_size(self):
    """Return the length of the state vector produced by ``get_state``."""
    # The size is measured by building one state rather than computed.
    return len(self.get_state())
|
|
|
|
async def expand_model_with_new_features(agent, env):
    """Grow the agent's model when the environment's state vector is wider.

    Compares ``env.get_expanded_state_size()`` with ``agent.state_size`` and,
    only if the former is larger, calls ``agent.expand_model`` with fixed
    larger hyper-parameters.

    Returns:
        True when no expansion was needed or the expansion succeeded,
        False when ``agent.expand_model`` reported failure.
    """
    new_state_size = env.get_expanded_state_size()

    # Nothing to do unless the environment now produces a wider state.
    if new_state_size <= agent.state_size:
        logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
        return True

    logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})")

    expanded = agent.expand_model(
        new_state_size=new_state_size,
        new_hidden_size=512,    # Increase hidden size for more capacity
        new_lstm_layers=3,      # More layers for deeper patterns
        new_attention_heads=8,  # More attention heads for complex relationships
    )

    if not expanded:
        logger.error("Failed to expand model")
        return False

    logger.info(f"Model successfully expanded to handle {new_state_size} features")
    return True
|
|
|
|
|
|
def calculate_reward(self, action):
|
|
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
|
reward = 0
|
|
|
|
# Base reward for actions
|
|
if action == 0: # HOLD
|
|
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
|
|
|
|
elif action == 1: # BUY/LONG
|
|
if self.position == 'flat':
|
|
# Opening a long position
|
|
self.position = 'long'
|
|
self.entry_price = self.current_price
|
|
self.position_size = self.calculate_position_size()
|
|
# Use the adjusted risk parameters
|
|
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
|
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
|
|
|
# Check if this is an optimal buy point (bottom)
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
|
|
reward += 3.0 # Increased bonus for buying at a bottom
|
|
|
|
# Check for volume spike (indicating potential big movement)
|
|
if len(self.features['volume']) > 5:
|
|
avg_volume = np.mean(self.features['volume'][-5:-1])
|
|
current_volume = self.features['volume'][-1]
|
|
if current_volume > avg_volume * 1.5:
|
|
reward += 2.0 # Bonus for entering during high volume
|
|
|
|
# Check for price action signals
|
|
if self.features['rsi'][-1] < 30: # Oversold condition
|
|
reward += 1.5 # Bonus for buying at oversold levels
|
|
|
|
# Check if we're buying in a clear uptrend (good)
|
|
if self.is_uptrend():
|
|
reward += 1.0 # Bonus for buying in uptrend
|
|
elif self.is_downtrend():
|
|
reward -= 0.25 # Reduced penalty for buying in downtrend
|
|
else:
|
|
reward += 0.2 # Small reward for opening a position
|
|
|
|
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif self.position == 'short':
|
|
# Close short and open long
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Record trade
|
|
trade_duration = len(self.features['price']) - self.entry_index
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar,
|
|
'duration': trade_duration,
|
|
'market_direction': self.get_market_direction()
|
|
})
|
|
|
|
# Reward based on PnL with stronger penalties for losses
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
# Stronger penalty for losses, scaled by the size of the loss
|
|
loss_penalty = 1.0 + abs(pnl_dollar) / 5
|
|
reward -= loss_penalty
|
|
self.loss_count += 1
|
|
|
|
# Extra penalty for closing a losing trade too quickly
|
|
if trade_duration < 5:
|
|
reward -= 0.5 # Penalty for very short losing trades
|
|
|
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Now open long
|
|
self.position = 'long'
|
|
self.entry_price = self.current_price
|
|
self.entry_index = len(self.features['price']) - 1
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
|
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
|
|
|
# Check if this is an optimal buy point
|
|
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
|
|
reward += 2.0 # Bonus for buying at a bottom
|
|
|
|
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif action == 2: # SELL/SHORT
|
|
if self.position == 'flat':
|
|
# Opening a short position
|
|
self.position = 'short'
|
|
self.entry_price = self.current_price
|
|
self.position_size = self.calculate_position_size()
|
|
# Use the adjusted risk parameters
|
|
self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
|
|
self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)
|
|
|
|
# Check if this is an optimal sell point (top)
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
|
|
reward += 3.0 # Increased bonus for selling at a top
|
|
|
|
# Check for volume spike
|
|
if len(self.features['volume']) > 5:
|
|
avg_volume = np.mean(self.features['volume'][-5:-1])
|
|
current_volume = self.features['volume'][-1]
|
|
if current_volume > avg_volume * 1.5:
|
|
reward += 2.0 # Bonus for entering during high volume
|
|
|
|
# Check for price action signals
|
|
if self.features['rsi'][-1] > 70: # Overbought condition
|
|
reward += 1.5 # Bonus for selling at overbought levels
|
|
|
|
# Check if we're selling in a clear downtrend (good)
|
|
if self.is_downtrend():
|
|
reward += 1.0 # Bonus for selling in downtrend
|
|
elif self.is_uptrend():
|
|
reward -= 0.25 # Reduced penalty for selling in uptrend
|
|
else:
|
|
reward += 0.2 # Small reward for opening a position
|
|
|
|
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif self.position == 'long':
|
|
# Close long and open short
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Now open short
|
|
self.position = 'short'
|
|
self.entry_price = self.current_price
|
|
self.position_size = self.calculate_position_size()
|
|
self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
|
|
self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)
|
|
|
|
# Check if this is an optimal sell point
|
|
current_idx = len(self.features['price']) - 1
|
|
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
|
|
reward += 2.0 # Bonus for selling at a top
|
|
|
|
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
|
|
|
|
elif action == 3: # CLOSE
|
|
if self.position == 'long':
|
|
# Close long position
|
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'long',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
elif self.position == 'short':
|
|
# Close short position
|
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
|
|
|
# Apply fees
|
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
|
|
|
# Update balance
|
|
self.balance += pnl_dollar
|
|
self.total_pnl += pnl_dollar
|
|
self.episode_pnl += pnl_dollar
|
|
|
|
# Update max drawdown
|
|
if self.balance > self.peak_balance:
|
|
self.peak_balance = self.balance
|
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
|
|
|
# Record trade
|
|
self.trades.append({
|
|
'type': 'short',
|
|
'entry': self.entry_price,
|
|
'exit': self.current_price,
|
|
'pnl_percent': pnl_percent,
|
|
'pnl_dollar': pnl_dollar
|
|
})
|
|
|
|
# Reward based on PnL
|
|
if pnl_dollar > 0:
|
|
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
|
self.win_count += 1
|
|
else:
|
|
reward -= 1.0 # Negative reward for loss
|
|
self.loss_count += 1
|
|
|
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
|
|
|
# Reset position
|
|
self.position = 'flat'
|
|
self.entry_price = 0
|
|
self.position_size = 0
|
|
self.stop_loss = 0
|
|
self.take_profit = 0
|
|
|
|
# Add prediction accuracy component to reward
|
|
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
|
|
# Compare the first prediction with actual price
|
|
if len(self.data) > 1:
|
|
actual_price = self.data[-1]['close']
|
|
predicted_price = self.predicted_prices[0]
|
|
prediction_error = abs(predicted_price - actual_price) / actual_price
|
|
|
|
# Reward accurate predictions, penalize bad ones
|
|
if prediction_error < 0.005: # Less than 0.5% error
|
|
reward += 0.5
|
|
elif prediction_error > 0.02: # More than 2% error
|
|
reward -= 0.5
|
|
|
|
return reward
|
|
|
|
def is_downtrend(self):
    """Report whether the fast EMA currently sits below the slow EMA.

    Returns:
        ``False`` while fewer than 20 prices are tracked; otherwise the
        result of comparing the latest ``ema_9`` value (strictly less than)
        against the latest ``ema_21`` value.
    """
    # Too little history to call a trend
    if len(self.features['price']) < 20:
        return False

    fast_ema, slow_ema = self.features['ema_9'][-1], self.features['ema_21'][-1]
    return fast_ema < slow_ema
|
|
|
|
def is_uptrend(self):
    """Report whether the fast EMA currently sits above the slow EMA.

    Returns:
        ``False`` while fewer than 20 prices are tracked; otherwise the
        result of comparing the latest ``ema_9`` value (strictly greater
        than) against the latest ``ema_21`` value.
    """
    # Too little history to call a trend
    if len(self.features['price']) < 20:
        return False

    fast_ema, slow_ema = self.features['ema_9'][-1], self.features['ema_21'][-1]
    return fast_ema > slow_ema
|
|
|
|
def get_market_direction(self):
    """Classify the market as ``"uptrend"``, ``"downtrend"`` or ``"sideways"``.

    The uptrend test is consulted first; the downtrend test is only
    evaluated when the market is not in an uptrend.
    """
    if self.is_uptrend():
        return "uptrend"
    return "downtrend" if self.is_downtrend() else "sideways"
|
|
|
|
def analyze_trades(self):
    """Summarise the recorded trades.

    Returns:
        ``{}`` when no trades are recorded. Otherwise a dict with the trade
        counts, the mean winning / losing dollar PnL, the mean ``duration``
        and a win rate (in percent) per ``market_direction`` value.  Trades
        with a PnL of exactly zero count as losses; missing ``pnl_dollar``
        or ``duration`` keys are treated as 0.
    """
    if not self.trades:
        return {}

    def mean_or_zero(values):
        return sum(values) / len(values) if values else 0

    pnl_values = [trade.get('pnl_dollar', 0) for trade in self.trades]
    winners = [pnl for pnl in pnl_values if pnl > 0]
    losers = [pnl for pnl in pnl_values if pnl <= 0]
    holding_times = [trade.get('duration', 0) for trade in self.trades]

    analysis = {
        'total_trades': len(self.trades),
        'winning_trades': len(winners),
        'losing_trades': len(losers),
        'avg_win': mean_or_zero(winners),
        'avg_loss': mean_or_zero(losers),
        'avg_duration': mean_or_zero(holding_times),
        'uptrend_win_rate': 0,
        'downtrend_win_rate': 0,
        'sideways_win_rate': 0,
    }

    # Win rate per market regime; regimes without trades keep the default 0
    for direction in ('uptrend', 'downtrend', 'sideways'):
        regime_trades = [t for t in self.trades if t.get('market_direction') == direction]
        if not regime_trades:
            continue
        regime_wins = sum(1 for t in regime_trades if t.get('pnl_dollar', 0) > 0)
        analysis[f'{direction}_win_rate'] = regime_wins / len(regime_trades) * 100

    return analysis
|
|
|
|
def initialize_price_predictor(self, device="cpu"):
    """Create the price prediction model together with its Adam optimizer.

    Args:
        device: Target passed to the model's ``.to()`` call.

    Also resets ``self.predicted_prices`` to an empty array.
    """
    predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
    self.price_predictor = predictor
    predictor.to(device)

    self.price_predictor_optimizer = optim.Adam(predictor.parameters(), lr=1e-3)
    self.predicted_prices = np.array([])
|
|
|
|
def train_price_predictor(self):
    """Train the price prediction model on the tracked price history.

    Returns:
        The loss reported by the model's ``train_on_new_data`` (5 epochs),
        or ``0.0`` when training is skipped because no predictor/optimizer
        has been initialised yet or fewer than 35 prices are available.
    """
    # The predictor starts out as None until initialize_price_predictor()
    # runs; skip quietly instead of raising AttributeError on None.
    predictor = getattr(self, 'price_predictor', None)
    optimizer = getattr(self, 'price_predictor_optimizer', None)
    if predictor is None or optimizer is None:
        return 0.0

    price_history = self.features['price']
    if len(price_history) < 35:
        return 0.0

    return predictor.train_on_new_data(price_history, optimizer, epochs=5)
|
|
|
|
def update_price_predictions(self):
    """Refresh ``self.predicted_prices`` from the price predictor.

    The predictions are cleared (set to an empty array) when fewer than 30
    prices are tracked, when no predictor is available, or when the
    predictor raises; in the last case a warning is logged.
    """
    price_history = self.features['price']
    predictor = getattr(self, 'price_predictor', None)

    if len(price_history) < 30 or predictor is None:
        self.predicted_prices = np.array([])
        return

    try:
        self.predicted_prices = predictor.predict_next_candles(price_history, num_candles=5)
    except Exception as e:
        logger.warning(f"Error updating predictions: {e}")
        self.predicted_prices = np.array([])
|
|
|
|
def identify_optimal_trades(self):
    """Mark local price extrema as optimal entry and exit points.

    Does nothing while fewer than 20 prices are tracked.  Otherwise stores
    the extrema indices in ``optimal_bottoms`` / ``optimal_tops`` and builds
    ``optimal_signals``: one value per price, ``1`` at bottoms (buy),
    ``-1`` at tops (sell) and ``0`` elsewhere.  An index reported as both
    bottom and top ends up as ``-1`` because tops are written last.
    """
    prices = self.features['price']
    if len(prices) < 20:
        return

    bottoms, tops = find_local_extrema(prices, window=5)
    self.optimal_bottoms = bottoms  # Buy points
    self.optimal_tops = tops  # Sell points

    signals = np.zeros(len(prices))
    signal_count = len(signals)
    for marker, indices in ((1, bottoms), (-1, tops)):
        for idx in indices:
            # Ignore any index that falls outside the signal array
            if 0 <= idx < signal_count:
                signals[idx] = marker
    self.optimal_signals = signals

    logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
import random
|
|
import logging
|
|
import asyncio
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
from collections import deque, namedtuple
|
|
from dotenv import load_dotenv
|
|
import ccxt
|
|
import websockets
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import torch.cuda.amp as amp # Add this import at the top
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
import copy
|
|
import argparse
|
|
import traceback
|
|
import io
|
|
import matplotlib.dates as mdates
|
|
from matplotlib.figure import Figure
|
|
from PIL import Image
|
|
import matplotlib.pyplot as mpf
|
|
import matplotlib.gridspec as gridspec
|
|
import datetime
|
|
from realtime import BinanceWebSocket, BinanceHistoricalData
|
|
from datetime import datetime as dt
|
|
# Add Dash-related imports
|
|
import dash
|
|
from dash import html, dcc, callback_context
|
|
from dash.dependencies import Input, Output, State
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
from threading import Thread
|
|
import socket
|
|
|
|
# Configure logging: every record goes to trading_bot.log and to the console
logging.basicConfig(
    handlers=[
        logging.FileHandler("trading_bot.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("trading_bot")

# Cap the 'websocket' logger at INFO so its DEBUG output is not emitted
websocket_logger = logging.getLogger('websocket')
websocket_logger.setLevel(logging.INFO)
|
|
|
|
# Filter that drops DEBUG records coming from WebSocket-related loggers
|
|
class WebSocketFilter(logging.Filter):
    """Logging filter that suppresses DEBUG records from WebSocket-related loggers."""

    def filter(self, record):
        """Return ``False`` for DEBUG records whose logger name contains
        ``websocket``, ``protocol`` or ``realtime``; ``True`` otherwise."""
        # Anything that is not DEBUG always passes
        if record.levelno != logging.DEBUG:
            return True

        noisy_markers = ('websocket', 'protocol', 'realtime')
        return not any(marker in record.name for marker in noisy_markers)
|
|
|
|
# A filter attached to a logger is only consulted for records logged directly
# on that logger, and "trading_bot" never matches the name patterns the filter
# checks.  Attach the same filter to the root handlers as well, so DEBUG noise
# propagated from websocket/protocol/realtime loggers is dropped before it is
# emitted.
_websocket_filter = WebSocketFilter()
logger.addFilter(_websocket_filter)
for _handler in logging.getLogger().handlers:
    _handler.addFilter(_websocket_filter)
|
|
|
|
# Pull settings from a .env file (if present) into the process environment,
# then read the MEXC credentials; each is None when the variable is not set.
load_dotenv()
MEXC_API_KEY = os.environ.get('MEXC_API_KEY')
MEXC_SECRET_KEY = os.environ.get('MEXC_SECRET_KEY')
|
|
|
|
# --- Account / risk constants ---
INITIAL_BALANCE = 100  # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5  # Very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5  # Take profit at 1.5%

# --- DQN training hyper-parameters ---
MEMORY_SIZE = 100_000
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10_000
STATE_SIZE = 64  # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10  # Update target network every 10 episodes

# One replay-buffer transition
Experience = namedtuple('Experience', 'state action reward next_state done')
|
|
|
|
# Helper for locating local price extrema (used to label optimal entry/exit points)
|
|
def find_local_extrema(prices, window=5):
    """Find local minima (bottoms) and maxima (tops) in price data.

    A point is a bottom when it is less than or equal to each of the
    ``window`` values on both sides, and a top when it is greater than or
    equal to each of them.  Because the comparisons are non-strict, a point
    on a flat stretch can be reported as both.

    Args:
        prices: Indexable sequence of prices.
        window: Number of neighbours inspected on each side.

    Returns:
        ``(bottoms, tops)`` as two lists of indices; both are empty when the
        series is shorter than ``2 * window + 1``.
    """
    bottoms, tops = [], []
    total = len(prices)
    if total < 2 * window + 1:
        return bottoms, tops

    offsets = range(1, window + 1)
    for idx in range(window, total - window):
        center = prices[idx]
        neighbours = [prices[idx - k] for k in offsets] + [prices[idx + k] for k in offsets]

        if all(center <= other for other in neighbours):
            bottoms.append(idx)
        if all(center >= other for other in neighbours):
            tops.append(idx)

    return bottoms, tops
|
|
|
|
class ReplayMemory:
    """Bounded FIFO buffer of ``Experience`` transitions for experience replay."""

    def __init__(self, capacity):
        # deque(maxlen=...) silently discards the oldest entry once full
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = Experience(state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement."""
        return random.sample(self.memory, batch_size)
|
|
|
|
class DQN(nn.Module):
    """Dueling Q-network that combines an LSTM branch and a transformer branch.

    The input state is projected to ``hidden_size``, passed through both an
    LSTM and a transformer encoder as a length-1 sequence, the two results
    are summed, and a dueling head turns them into one Q-value per action.
    """

    def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
        super().__init__()

        self.state_size = state_size
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers

        # Input projection (LayerNorm rather than BatchNorm so any batch size works)
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(0.2)

        # Recurrent branch
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)

        # Attention module; registered here but not called in forward()
        self.attention = nn.MultiheadAttention(hidden_size, attention_heads)

        # Trunk after the two branches are merged
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(hidden_size, hidden_size // 2)

        # Dueling heads: state value and per-action advantage
        self.value_stream = nn.Linear(hidden_size // 2, 1)
        self.advantage_stream = nn.Linear(hidden_size // 2, action_size)

        # Transformer branch
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, x):
        """Compute Q-values for a state or a batch of states.

        A 1-D input is treated as a single state.  Inputs whose second
        dimension differs from ``state_size`` are truncated or right-padded
        with zeros to fit.

        Returns:
            Tensor of Q-values with one row per state and one column per action.
        """
        batch_size = x.size(0) if x.dim() > 1 else 1
        if x.dim() == 1:
            x = x.unsqueeze(0)  # Add batch dimension

        # Force the feature dimension to the expected width
        width = x.size(1)
        if width > self.state_size:
            x = x[:, :self.state_size]
        elif width < self.state_size:
            padding = torch.zeros(batch_size, self.state_size - width, device=x.device)
            x = torch.cat([x, padding], dim=1)

        features = self.dropout1(F.relu(self.ln1(self.fc1(x))))

        # Both branches consume the same tensor, with a sequence axis added
        # when the features are 2-D
        sequence = features.unsqueeze(1) if features.dim() == 2 else features

        lstm_out, _ = self.lstm(sequence)
        if lstm_out.size(1) == 1:
            lstm_out = lstm_out.squeeze(1)
        else:
            lstm_out = lstm_out[:, -1]

        # The encoder layers are not batch_first, hence the transposes
        transformer_out = self.transformer_encoder(sequence.transpose(0, 1))
        transformer_out = transformer_out.transpose(0, 1).mean(dim=1)

        merged = lstm_out + transformer_out

        hidden = self.dropout2(F.relu(self.ln2(self.fc2(merged))))
        hidden = F.relu(self.fc3(hidden))

        # Dueling combination: Q = V + (A - mean(A))
        value = self.value_stream(hidden)
        advantages = self.advantage_stream(hidden)
        return value + (advantages - advantages.mean(dim=1, keepdim=True))
|
|
|
|
class PricePredictionModel(nn.Module):
    """LSTM regressor that forecasts the next few prices from a recent window.

    Prices are scaled with a ``MinMaxScaler`` that is fitted once, on the
    first data passed to :meth:`preprocess`, and reused afterwards.

    Args:
        input_size: Number of past prices fed to the LSTM per sample.
        hidden_size: LSTM hidden width.
        output_size: Number of future prices predicted per sample.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
        super().__init__()
        # Window length and forecast horizon; these were previously
        # hard-coded as 30 and 5 regardless of the constructor arguments.
        self.input_size = input_size
        self.output_size = output_size

        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False

    def forward(self, x):
        """Map ``[batch_size, seq_len, 1]`` sequences to ``output_size`` predictions."""
        lstm_out, _ = self.lstm(x)
        # Only the last time step feeds the regression head
        return self.fc(lstm_out[:, -1, :])

    def preprocess(self, data):
        """Scale ``data`` into a column vector, fitting the scaler on first use."""
        column = np.array(data).reshape(-1, 1)

        if not self.is_fitted:
            self.scaler.fit(column)
            self.is_fitted = True

        return self.scaler.transform(column)

    def postprocess(self, scaled_predictions):
        """Invert the scaling and return a flat array of price values."""
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        """Predict upcoming prices from the most recent ``input_size`` prices.

        Returns:
            An array of ``output_size`` predicted prices, or
            ``np.zeros(num_candles)`` when the history is shorter than
            ``input_size``.
        """
        window = self.input_size
        if len(price_history) < window:  # Need enough history
            return np.zeros(num_candles)

        scaled_data = self.preprocess(price_history)
        device = next(self.parameters()).device
        sequence = np.asarray(scaled_data[-window:]).reshape(1, window, 1)
        sequence_tensor = torch.tensor(sequence, dtype=torch.float32, device=device)

        # Run inference in eval mode so dropout cannot perturb the forecast,
        # then restore whatever mode the module was in.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                scaled_predictions = self(sequence_tensor).cpu().numpy()[0]
        finally:
            self.train(was_training)

        return self.postprocess(scaled_predictions)

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        """Fit the model on every window/target pair found in ``price_history``.

        Each sample is ``input_size`` consecutive scaled prices, with the
        following ``output_size`` prices as target.

        Returns:
            Mean MSE loss over the epochs, or ``0.0`` when there is too
            little history (fewer than ``input_size + output_size`` prices)
            or ``epochs`` is not positive.
        """
        window, horizon = self.input_size, self.output_size
        span = window + horizon
        if len(price_history) < span or epochs <= 0:
            return 0.0

        scaled_data = np.asarray(self.preprocess(price_history))

        # +1 so the newest complete window is included; without it a history
        # of exactly `span` prices produced no training sample at all.
        sample_count = len(scaled_data) - span + 1
        sequences = np.stack([scaled_data[i:i + window] for i in range(sample_count)])
        targets = np.stack([scaled_data[i + window:i + span].ravel() for i in range(sample_count)])

        device = next(self.parameters()).device
        sequences_tensor = torch.tensor(sequences.reshape(-1, window, 1), dtype=torch.float32, device=device)
        targets_tensor = torch.tensor(targets, dtype=torch.float32, device=device)

        # Train with dropout enabled, then restore the previous mode
        was_training = self.training
        self.train()
        total_loss = 0.0
        try:
            for _ in range(epochs):
                predictions = self(sequences_tensor)
                loss = F.mse_loss(predictions, targets_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
        finally:
            self.train(was_training)

        return total_loss / epochs
|
|
|
|
class TradingEnvironment:
|
|
"""Trading environment for reinforcement learning with enhanced features"""
|
|
def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
    """Set up a flat account with empty market data and feature buffers.

    Args:
        initial_balance: Starting account balance
        window_size: Number of candles in the state window
        demo: Whether to run in demo mode
    """
    # Configuration
    self.initial_balance = initial_balance
    self.window_size = window_size
    self.demo = demo

    # Account state
    self.balance = initial_balance
    self.peak_balance = initial_balance
    self.max_drawdown = 0.0
    self.total_pnl = 0.0
    self.episode_pnl = 0.0
    self.win_count = 0
    self.loss_count = 0
    self.trades = []

    # Open-position state
    self.position = 'flat'  # 'flat', 'long', or 'short'
    self.position_size = 0
    self.entry_price = 0
    self.entry_index = 0
    self.stop_loss = 0
    self.take_profit = 0

    # Market data cursor
    self.data = []
    self.current_step = 0
    self.current_price = 0

    # Risk management parameters (adjusted for more aggressive trading)
    self.stop_loss_pct = STOP_LOSS_PERCENT * 0.8  # 80% of the configured stop loss
    self.take_profit_pct = TAKE_PROFIT_PERCENT * 1.5  # 150% of the configured take profit
    self.trailing_stop_activated = False
    self.trailing_stop_distance = 0
    self.max_position_size_pct = 0.8  # Use up to 80% of balance for position size

    # Signals recorded for visualization
    self.trade_signals = []

    # One (initially empty) series per tracked feature
    feature_names = (
        'price', 'volume', 'rsi',
        'macd', 'macd_signal', 'macd_hist',
        'bollinger_upper', 'bollinger_mid', 'bollinger_lower',
        'stoch_k', 'stoch_d',
        'ema_9', 'ema_21',
        'atr',
    )
    self.features = {name: [] for name in feature_names}

    # Price predictor is created later, on demand
    self.price_predictor = None
    self.predicted_prices = np.array([])

    # Optimal trade tracking
    self.optimal_bottoms = []
    self.optimal_tops = []
    self.optimal_signals = np.array([])

    # Futures settings
    self.leverage = MAX_LEVERAGE
    self.futures_symbol = "ETH_USDT"  # Example futures symbol
    self.position_mode = "hedge"  # For simultaneous long/short positions
    self.margin_mode = "cross"  # Cross margin mode
|
|
|
|
def reset(self):
    """Restore the account to its starting state and return the first state.

    Loaded candle data is kept and ``total_pnl`` is left untouched.  When
    more than ``window_size`` candles are available the cursor is placed at
    index ``window_size`` and ``current_price`` is set to that candle's
    close; otherwise the cursor is 0 and ``current_price`` is unchanged.
    """
    # Account
    self.balance = self.peak_balance = self.initial_balance
    self.episode_pnl = self.max_drawdown = 0.0
    self.win_count = self.loss_count = 0
    self.trades, self.trade_signals = [], []

    # Position
    self.position = 'flat'
    self.position_size = self.entry_price = self.entry_index = 0
    self.stop_loss = self.take_profit = 0

    # Data cursor
    has_full_window = len(self.data) > self.window_size
    self.current_step = self.window_size if has_full_window else 0
    if has_full_window:
        self.current_price = self.data[self.current_step]['close']

    return self.get_state()
|
|
|
|
def add_data(self, candle):
    """Append ``candle`` to the data, refresh the features and adopt its
    ``'close'`` value as the current price."""
    self.data.append(candle)
    self._update_features()
    latest_close = candle['close']
    self.current_price = latest_close
|
|
|
|
def _initialize_features(self):
    """Recompute every technical indicator from the full candle history.

    Does nothing while fewer than 30 candles are loaded.  Otherwise each
    entry of ``self.features`` is replaced by an array with one value per
    candle.  Candles are expected to provide ``close``, ``high``, ``low``
    and ``volume`` fields.  The Bollinger and stochastic series keep NaN
    for their warm-up period.
    """
    if len(self.data) < 30:
        return

    # A DataFrame makes the rolling/EWM calculations straightforward
    df = pd.DataFrame(self.data)
    close = df['close']

    # Basic price and volume
    self.features['price'] = close.values
    self.features['volume'] = df['volume'].values

    # RSI (14 periods); undefined ratios fall back to the neutral value 50
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
    rs = gain / loss
    self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values

    # MACD (12/26 EMA difference with a 9-period signal line)
    ema12 = close.ewm(span=12, adjust=False).mean()
    ema26 = close.ewm(span=26, adjust=False).mean()
    macd = ema12 - ema26
    signal = macd.ewm(span=9, adjust=False).mean()
    self.features['macd'] = macd.values
    self.features['macd_signal'] = signal.values
    self.features['macd_hist'] = (macd - signal).values

    # Bollinger Bands (20-period SMA +/- 2 standard deviations)
    sma20 = close.rolling(window=20).mean()
    std20 = close.rolling(window=20).std()
    self.features['bollinger_upper'] = (sma20 + 2 * std20).values
    self.features['bollinger_mid'] = sma20.values
    self.features['bollinger_lower'] = (sma20 - 2 * std20).values

    # Stochastic Oscillator (14-period %K, 3-period %D)
    low_14 = df['low'].rolling(window=14).min()
    high_14 = df['high'].rolling(window=14).max()
    price_range = high_14 - low_14
    # A completely flat 14-candle window has zero range; dividing by it
    # would put NaN into the most recent values, so report the neutral
    # midpoint (50) there instead.
    k = 100 * ((close - low_14) / price_range.where(price_range != 0))
    k = k.mask(price_range == 0, 50.0)
    self.features['stoch_k'] = k.values
    self.features['stoch_d'] = k.rolling(window=3).mean().values

    # EMAs used for trend detection
    self.features['ema_9'] = close.ewm(span=9, adjust=False).mean().values
    self.features['ema_21'] = close.ewm(span=21, adjust=False).mean().values

    # ATR (14-period mean of the true range)
    prev_close = close.shift()
    high_low = df['high'] - df['low']
    high_close = (df['high'] - prev_close).abs()
    low_close = (df['low'] - prev_close).abs()
    tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
    self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values
|
|
|
|
def _update_features(self):
    """Bring the indicators up to date after new data has arrived.

    Implemented as a full recalculation of every feature series.
    """
    self._initialize_features()
|
|
|
|
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
    """Fetch historical candles and initialise the environment with them.

    Args:
        exchange: Exchange handle forwarded to ``fetch_ohlcv_data``.
        symbol: Market symbol to fetch.
        timeframe: Candle timeframe.
        limit: Maximum number of candles requested.

    Returns:
        ``True`` when candles were received and the features initialised,
        ``False`` when nothing was received or an error occurred (errors
        are logged, not raised).
    """
    try:
        logger.info(f"Fetching initial data for {symbol}")

        data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)

        # Return here instead of falling through to len(data): for a None
        # result that raised TypeError and produced a misleading error log
        # on top of the warning.
        if not data:
            logger.warning("No initial data received")
            return False

        self.data = data
        self._initialize_features()
        logger.info(f"Initialized environment with {len(data)} candles")
        return True
    except Exception as e:
        logger.error(f"Error fetching initial data: {e}")
        return False
|
|
|
|
def step(self, action):
    """Apply ``action`` to the environment and advance by one candle.

    Args:
        action: 0 = HOLD, 1 = BUY/LONG, 2 = SELL/SHORT, 3 = CLOSE.

    Returns:
        ``(next_state, reward, done, info)``.  Once the cursor has reached
        the last candle no action is processed: the reward is 0, ``done`` is
        ``True`` and ``info['action']`` is ``'none'``.
    """
    # End of data: nothing left to trade on
    if self.current_step >= len(self.data) - 1:
        next_state = self.get_state()
        info = {
            'action': 'none',
            'price': self.current_price,
            'balance': self.balance,
            'position': self.position,
            'pnl': self.total_pnl
        }
        return next_state, 0, True, info

    # Adapt trading parameters to current market conditions
    self.adapt_trading_parameters_to_market()

    # Price at which the action is evaluated
    self.current_price = self.data[self.current_step]['close']

    # Remember which side was open *before* the action is processed:
    # calculate_reward() flattens the position on CLOSE, so reading
    # self.position afterwards could never yield a close_long/close_short
    # signal.
    position_before = self.position
    reward = self.calculate_reward(action)

    # Record trade signal for visualization
    if action > 0:  # If not HOLD
        signal_type = None
        if action == 1:  # BUY/LONG
            signal_type = 'buy'
        elif action == 2:  # SELL/SHORT
            signal_type = 'sell'
        elif action == 3:  # CLOSE
            if position_before == 'long':
                signal_type = 'close_long'
            elif position_before == 'short':
                signal_type = 'close_short'

        if signal_type:
            self.trade_signals.append({
                'timestamp': self.data[self.current_step]['timestamp'],
                'price': self.current_price,
                'type': signal_type,
                'balance': self.balance,
                'pnl': self.total_pnl
            })

    # Check for stop loss / take profit hits
    self.check_sl_tp()

    # Move to next step
    self.current_step += 1
    done = self.current_step >= len(self.data) - 1

    next_state = self.get_state()

    info = {
        'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
        'price': self.current_price,
        'balance': self.balance,
        'position': self.position,
        'pnl': self.total_pnl
    }

    return next_state, reward, done, info
|
|
|
|
def check_sl_tp(self):
    """Close the open position if its stop loss or take profit was hit.

    A long position is stopped out when the current price is at or below
    the stop loss and takes profit when it is at or above the take profit;
    a short position uses the mirrored comparisons.  The stop loss is
    checked first.  A triggered exit is booked at the stop-loss /
    take-profit level itself (not at the current price), fees are
    deducted, and balance, PnL totals, drawdown, trade history, win/loss
    counters and visualization signals are updated before the position is
    reset to flat.
    """
    side = self.position
    if side == 'flat':
        return

    def settle(exit_price, reason):
        """Book the exit at ``exit_price`` and flatten the position."""
        entry = self.entry_price
        price_move = exit_price - entry if side == 'long' else entry - exit_price
        pnl_percent = price_move / entry * 100
        pnl_dollar = pnl_percent / 100 * self.position_size

        # Apply fees
        pnl_dollar -= self.calculate_fees(self.position_size)

        # Update balance
        self.balance += pnl_dollar
        self.total_pnl += pnl_dollar
        self.episode_pnl += pnl_dollar

        # Update max drawdown.  This runs for take-profit exits too, which
        # previously only refreshed the peak balance.
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance
        drawdown = (self.peak_balance - self.balance) / self.peak_balance
        self.max_drawdown = max(self.max_drawdown, drawdown)

        # Record trade
        self.trades.append({
            'type': side,
            'entry': entry,
            'exit': exit_price,
            'pnl_percent': pnl_percent,
            'pnl_dollar': pnl_dollar,
            'duration': self.current_step - self.entry_index,
            'market_direction': self.get_market_direction(),
            'reason': reason
        })

        # A stop loss always counts as a loss, a take profit as a win
        if reason == 'stop_loss':
            self.loss_count += 1
            label = 'STOP LOSS'
        else:
            self.win_count += 1
            label = 'TAKE PROFIT'

        logger.info(f"{label} hit for {side} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

        # Record signal for visualization
        self.trade_signals.append({
            'timestamp': self.data[self.current_step]['timestamp'],
            'price': exit_price,
            'type': f'{reason}_{side}',
            'balance': self.balance,
            'pnl': self.total_pnl
        })

        # Reset position
        self.position = 'flat'
        self.entry_price = 0
        self.entry_index = 0
        self.position_size = 0
        self.stop_loss = 0
        self.take_profit = 0

    price = self.current_price
    if side == 'long':
        if price <= self.stop_loss:
            settle(self.stop_loss, 'stop_loss')
        elif price >= self.take_profit:
            settle(self.take_profit, 'take_profit')
    elif side == 'short':
        if price >= self.stop_loss:
            settle(self.stop_loss, 'stop_loss')
        elif price <= self.take_profit:
            settle(self.take_profit, 'take_profit')
|
|
|
|
def get_state(self):
    """Create state representation for the agent with enhanced features.

    The vector is assembled from fixed-width groups, in this order:
    price window (10), volume window (5), RSI (3), MACD / signal /
    histogram (3 each), Bollinger position (3), stochastic %K and %D
    (3 each), predicted prices (3), optimal-trade signals (5), position
    info (5), momentum (3), volatility (3), market regime (5) and trade
    history (4).

    Every group is padded or trimmed to its own width before the groups
    are concatenated, so a short indicator history zero-fills its own slot
    instead of shifting every later feature to a different index.

    Returns:
        np.ndarray: 1-D array of length STATE_SIZE with NaN and infinite
        values replaced by 0. An all-zero vector is returned when there is
        not enough data to build a state.
    """
    # Ensure we have enough data
    if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
        return np.zeros(STATE_SIZE)

    def fit(values, width, pad_left=True):
        """Coerce `values` to a 1-D float array of exactly `width` entries."""
        arr = np.asarray(values, dtype=float).flatten()
        if len(arr) > width:
            arr = arr[-width:] if pad_left else arr[:width]
        elif len(arr) < width:
            missing = width - len(arr)
            arr = np.pad(arr, (missing, 0) if pad_left else (0, missing), 'constant')
        return arr

    def ratio(numerator, denominator):
        """Scalar division that yields 0.0 instead of raising on a zero denominator."""
        return numerator / denominator if denominator else 0.0

    prices = self.features['price']
    latest_price = prices[-1]
    state_components = []

    # Price features (recent prices relative to the latest price)
    if latest_price:
        price_features = np.asarray(prices[-10:], dtype=float) / latest_price - 1.0
    else:
        price_features = np.zeros(10)
    state_components.append(fit(price_features, 10))

    # Volume features (normalize by the max volume of the last 20 candles)
    volumes = self.features['volume']
    max_vol = max(volumes[-20:]) if len(volumes) >= 20 else 1
    if max_vol:
        vol_features = np.asarray(volumes[-5:], dtype=float) / max_vol
    else:
        vol_features = np.zeros(5)
    state_components.append(fit(vol_features, 5))

    # RSI scaled to 0-1
    state_components.append(fit(np.asarray(self.features['rsi'][-3:], dtype=float) / 100.0, 3))

    # MACD family, all divided by the largest absolute MACD value in the window
    macd_vals = fit(self.features['macd'][-3:], 3)
    macd_signal = fit(self.features['macd_signal'][-3:], 3)
    macd_hist = fit(self.features['macd_hist'][-3:], 3)
    macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
    state_components.extend([macd_vals / macd_scale, macd_signal / macd_scale, macd_hist / macd_scale])

    # Bollinger position: where price sits between the lower (0) and upper (1) band
    bb_upper = self.features['bollinger_upper'][-3:]
    bb_lower = self.features['bollinger_lower'][-3:]
    bb_pos = [
        (p - l) / (u - l) if u != l else 0.5
        for p, u, l in zip(prices[-3:], bb_upper, bb_lower)
    ]
    state_components.append(fit(bb_pos, 3))

    # Stochastic oscillator scaled to 0-1
    state_components.append(fit(np.asarray(self.features['stoch_k'][-3:], dtype=float) / 100.0, 3))
    state_components.append(fit(np.asarray(self.features['stoch_d'][-3:], dtype=float) / 100.0, 3))

    # Predicted prices relative to the current price (zeros if unavailable)
    predictions = getattr(self, 'predicted_prices', None)
    if predictions is not None and len(predictions) > 0 and latest_price:
        pred_norm = np.asarray(predictions[:3], dtype=float) / latest_price - 1.0
        # Predictions run forward in time, so missing later steps pad on the right
        state_components.append(fit(pred_norm, 3, pad_left=False))
    else:
        state_components.append(np.zeros(3))

    # Most recent optimal-trade signals, padded on the right as before
    signals = getattr(self, 'optimal_signals', None)
    if signals is not None and len(signals) > 0:
        state_components.append(fit(signals[-5:], 5, pad_left=False))
    else:
        state_components.append(np.zeros(5))

    # Position info: direction, unrealized PnL, stop distance, target distance, size
    position_info = np.zeros(5)
    if self.position in ('long', 'short'):
        direction = 1.0 if self.position == 'long' else -1.0
        entry = self.entry_price
        position_info[0] = direction
        position_info[1] = ratio(direction * (latest_price - entry), entry)
        position_info[2] = ratio(direction * (self.stop_loss - entry), entry)
        position_info[3] = ratio(direction * (self.take_profit - entry), entry)
        position_info[4] = ratio(self.position_size, self.balance)
    state_components.append(position_info)

    # 1. Price momentum (rate of change over 5, 10 and 20 candles)
    if len(prices) >= 20:
        momentum = [
            (latest_price / prices[-lag] - 1.0) if prices[-lag] != 0 else 0
            for lag in (5, 10, 20)
        ]
        state_components.append(fit(momentum, 3))
    else:
        state_components.append(np.zeros(3))

    # 2. Volatility features
    if len(prices) >= 20:
        window = np.asarray(prices[-21:], dtype=float)
        # Standard deviation of simple returns over the window
        volatility = np.std(np.diff(window) / window[:-1])
        # Mean high-low range of the last five candles, relative to their close
        recent_candles = range(max(0, len(self.data) - 5), len(self.data))
        high_low_range = np.mean([
            ratio(self.data[i]['high'] - self.data[i]['low'], self.data[i]['close'])
            for i in recent_candles
        ])
        # ATR normalized by price
        atr = self.features['atr']
        atr_norm = ratio(atr[-1], latest_price) if len(atr) > 0 else 0
        state_components.append(fit([volatility, high_low_range, atr_norm], 3))
    else:
        state_components.append(np.zeros(3))

    # 3. Market regime features
    if len(prices) >= 50:
        ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price
        ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price
        trend_strength = ratio(abs(ema9 - ema21), ema21)

        is_range_bound = 1.0 if self.is_uncertain_market() else 0.0
        is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0
        near_support = 1.0 if self.is_near_support() else 0.0
        near_resistance = 1.0 if self.is_near_resistance() else 0.0

        state_components.append(
            fit([trend_strength, is_range_bound, is_trending, near_support, near_resistance], 5)
        )
    else:
        state_components.append(np.zeros(5))

    # 4. Trade history features (last ten closed trades)
    if len(self.trades) > 0:
        recent_pnls = [t.get('pnl_dollar', 0) for t in self.trades[-10:]]
        gains = [p for p in recent_pnls if p > 0]
        losses = [p for p in recent_pnls if p <= 0]

        win_ratio = len(gains) / len(recent_pnls)
        avg_profit = np.mean(gains) if gains else 0
        avg_loss = np.mean(losses) if losses else 0

        # Normalize dollar figures by the current balance
        has_balance = self.balance > 0
        avg_profit_norm = avg_profit / self.balance if has_balance else 0
        avg_loss_norm = avg_loss / self.balance if has_balance else 0
        last_trade_pnl = recent_pnls[-1] / self.balance if has_balance else 0

        state_components.append(fit([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl], 4))
    else:
        state_components.append(np.zeros(4))

    # Combine all features
    state = np.concatenate([comp.flatten() for comp in state_components])

    # Replace any NaN or infinite values
    state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)

    # Ensure the state has the correct size
    if len(state) != STATE_SIZE:
        logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
        # Pad or truncate to match expected size
        if len(state) < STATE_SIZE:
            state = np.pad(state, (0, STATE_SIZE - len(state)))
        else:
            state = state[:STATE_SIZE]

    return state
|
|
|
|
def get_expanded_state_size(self):
    """Return the number of entries in the vector produced by get_state()."""
    return len(self.get_state())
|
|
|
|
async def expand_model_with_new_features(agent, env):
    """Grow the agent's network when the environment's state vector is wider.

    Args:
        agent: Object exposing `state_size` and `expand_model(...)`.
        env: Object exposing `get_expanded_state_size()`.

    Returns:
        bool: True when no expansion was needed or the expansion succeeded,
        False when `agent.expand_model` reported failure.
    """
    target_size = env.get_expanded_state_size()

    # Nothing to do unless the environment now produces more features
    if target_size <= agent.state_size:
        logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient")
        return True

    logger.info(f"Expanding model to handle {target_size} features (was {agent.state_size})")

    expanded = agent.expand_model(
        new_state_size=target_size,
        new_hidden_size=512,      # Increase hidden size for more capacity
        new_lstm_layers=3,        # More layers for deeper patterns
        new_attention_heads=8,    # More attention heads for complex relationships
    )

    if not expanded:
        logger.error("Failed to expand model")
        return False

    logger.info(f"Model successfully expanded to handle {target_size} features")
    return True
|
|
|
|
|
|
def calculate_reward(self, action):
    """Apply `action` to the simulated position and return its reward.

    Actions: 0 = HOLD, 1 = BUY/LONG, 2 = SELL/SHORT, 3 = CLOSE.

    Opening from flat earns bonuses for entering at a marked optimal
    bottom/top, on a volume spike, at an RSI extreme and with the EMA
    trend. Closing (directly or as part of a reversal) is rewarded or
    penalised from the realised PnL after fees. A small prediction-accuracy
    term is added at the end when price predictions exist.

    Bookkeeping is identical for every close: balance, total and episode
    PnL, peak balance / max drawdown, win/loss counters, and a trade record
    that always carries `duration` and `market_direction`. Every open
    records `entry_index`, so durations are measured from the real entry.

    Args:
        action (int): One of 0, 1, 2, 3 as described above. Any other value
            leaves the position untouched.

    Returns:
        float: The reward for this step.
    """
    reward = 0
    # Index of the latest price; used both as entry marker and extremum lookup
    price_idx = len(self.features['price']) - 1

    def at_marked_extremum(attr_name):
        """True when the latest price index is listed in self.<attr_name>."""
        return hasattr(self, attr_name) and price_idx in getattr(self, attr_name)

    def open_position(side):
        """Enter `side` at the current price and set stop loss / take profit."""
        self.position = side
        self.entry_price = self.current_price
        self.entry_index = price_idx
        self.position_size = self.calculate_position_size()
        if side == 'long':
            self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
        else:
            self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100)
            self.take_profit = self.entry_price * (1 - self.take_profit_pct/100)
        logger.info(f"OPENED {side.upper()} at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

    def entry_signal_bonus(side):
        """Bonus for the quality of a fresh entry from a flat position."""
        going_long = side == 'long'
        bonus = 0

        # Entering exactly at a marked bottom (long) or top (short)
        if at_marked_extremum('optimal_bottoms' if going_long else 'optimal_tops'):
            bonus += 3.0

        # Volume spike: latest volume vs. the mean of the four candles before it
        volumes = self.features['volume']
        if len(volumes) > 5 and volumes[-1] > np.mean(volumes[-5:-1]) * 1.5:
            bonus += 2.0

        # Oversold for longs, overbought for shorts; skipped when RSI is empty
        rsi = self.features['rsi']
        if len(rsi) > 0 and (rsi[-1] < 30 if going_long else rsi[-1] > 70):
            bonus += 1.5

        if going_long:
            favourable, adverse = self.is_uptrend, self.is_downtrend
        else:
            favourable, adverse = self.is_downtrend, self.is_uptrend
        if favourable():
            bonus += 1.0    # Entering with the trend
        elif adverse():
            bonus -= 0.25   # Reduced penalty for entering against the trend
        else:
            bonus += 0.2    # Small reward for opening a position
        return bonus

    def close_position():
        """Realise PnL for the open position; return (pnl_dollar, duration)."""
        side = self.position
        if side == 'long':
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
        else:
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
        pnl_dollar = pnl_percent / 100 * self.position_size

        # Apply fees
        pnl_dollar -= self.calculate_fees(self.position_size)

        # Update balance and PnL trackers
        self.balance += pnl_dollar
        self.total_pnl += pnl_dollar
        self.episode_pnl += pnl_dollar

        # Update max drawdown
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance
        if self.peak_balance > 0:
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)

        # Same formula as the pre-existing reversal branch, which the
        # "closed too quickly" threshold below was tuned against.
        # NOTE(review): the stop-loss/take-profit code measures duration
        # from self.current_step instead -- confirm both use the same base.
        duration = len(self.features['price']) - self.entry_index

        # Record trade
        self.trades.append({
            'type': side,
            'entry': self.entry_price,
            'exit': self.current_price,
            'pnl_percent': pnl_percent,
            'pnl_dollar': pnl_dollar,
            'duration': duration,
            'market_direction': self.get_market_direction()
        })

        if pnl_dollar > 0:
            self.win_count += 1
        else:
            self.loss_count += 1

        logger.info(f"CLOSED {side} at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
        return pnl_dollar, duration

    def go_flat():
        """Clear all position fields after a plain close."""
        self.position = 'flat'
        self.entry_price = 0
        self.entry_index = 0
        self.position_size = 0
        self.stop_loss = 0
        self.take_profit = 0

    if action == 0:  # HOLD
        reward = -0.05  # Penalty for doing nothing to encourage more trading

    elif action == 1:  # BUY/LONG
        if self.position == 'flat':
            open_position('long')
            reward += entry_signal_bonus('long')

        elif self.position == 'short':
            # Close short and open long
            pnl_dollar, duration = close_position()
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10
            else:
                # This reversal path scales the penalty with the loss size;
                # the other close paths use a flat -1.0. Kept as-is so the
                # reward values the agent was trained on do not change.
                reward -= 1.0 + abs(pnl_dollar) / 5
                if duration < 5:
                    reward -= 0.5  # Penalty for very short losing trades

            open_position('long')
            if at_marked_extremum('optimal_bottoms'):
                reward += 2.0  # Bonus for buying at a bottom

    elif action == 2:  # SELL/SHORT
        if self.position == 'flat':
            open_position('short')
            reward += entry_signal_bonus('short')

        elif self.position == 'long':
            # Close long and open short
            pnl_dollar, _ = close_position()
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10
            else:
                reward -= 1.0

            open_position('short')
            if at_marked_extremum('optimal_tops'):
                reward += 2.0  # Bonus for selling at a top

    elif action == 3:  # CLOSE
        if self.position in ('long', 'short'):
            pnl_dollar, _ = close_position()
            if pnl_dollar > 0:
                reward += 1.0 + pnl_dollar / 10
            else:
                reward -= 1.0
            go_flat()

    # Prediction accuracy component: first predicted price vs. the close of
    # the last candle in self.data
    if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
        if len(self.data) > 1:
            actual_price = self.data[-1]['close']
            if actual_price:
                prediction_error = abs(self.predicted_prices[0] - actual_price) / actual_price
                if prediction_error < 0.005:    # Less than 0.5% error
                    reward += 0.5
                elif prediction_error > 0.02:   # More than 2% error
                    reward -= 0.5

    return reward
|
|
|
|
def is_downtrend(self):
    """Check if the market is in a downtrend.

    Returns:
        bool: True when the latest 9-period EMA is below the latest
        21-period EMA. False when fewer than 20 prices are available or
        either EMA series is still empty.
    """
    if len(self.features['price']) < 20:
        return False

    fast_ema = self.features['ema_9']
    slow_ema = self.features['ema_21']
    # Other checks in this class test the EMA series for emptiness before
    # indexing; do the same here instead of raising IndexError.
    if len(fast_ema) == 0 or len(slow_ema) == 0:
        return False

    # Downtrend if short EMA is below long EMA
    return fast_ema[-1] < slow_ema[-1]
|
|
|
|
def is_uptrend(self):
    """Check if the market is in an uptrend.

    Returns:
        bool: True when the latest 9-period EMA is above the latest
        21-period EMA. False when fewer than 20 prices are available or
        either EMA series is still empty.
    """
    if len(self.features['price']) < 20:
        return False

    fast_ema = self.features['ema_9']
    slow_ema = self.features['ema_21']
    # Other checks in this class test the EMA series for emptiness before
    # indexing; do the same here instead of raising IndexError.
    if len(fast_ema) == 0 or len(slow_ema) == 0:
        return False

    # Uptrend if short EMA is above long EMA
    return fast_ema[-1] > slow_ema[-1]
|
|
|
|
def get_market_direction(self):
    """Classify the market as "uptrend", "downtrend" or "sideways".

    The uptrend check is consulted first; the downtrend check only runs
    when it fails.
    """
    if self.is_uptrend():
        return "uptrend"
    return "downtrend" if self.is_downtrend() else "sideways"
|
|
|
|
def analyze_trades(self):
    """Summarise the closed trades held in self.trades.

    Returns:
        dict: Empty when there are no trades. Otherwise trade counts,
        average win / loss in dollars, average duration, and the win rate
        (in percent) for each recorded market direction. A trade with
        `pnl_dollar` of exactly 0 counts as a loss.
    """
    if not self.trades:
        return {}

    def mean_or_zero(values):
        return sum(values) / len(values) if values else 0

    pnls = [trade.get('pnl_dollar', 0) for trade in self.trades]
    gains = [pnl for pnl in pnls if pnl > 0]
    setbacks = [pnl for pnl in pnls if pnl <= 0]
    holding_times = [trade.get('duration', 0) for trade in self.trades]

    report = {
        'total_trades': len(self.trades),
        'winning_trades': len(gains),
        'losing_trades': len(setbacks),
        'avg_win': mean_or_zero(gains),
        'avg_loss': mean_or_zero(setbacks),
        'avg_duration': mean_or_zero(holding_times),
        'uptrend_win_rate': 0,
        'downtrend_win_rate': 0,
        'sideways_win_rate': 0
    }

    # Win rate per market direction; directions without trades stay at 0
    for regime in ('uptrend', 'downtrend', 'sideways'):
        outcomes = [
            trade.get('pnl_dollar', 0) > 0
            for trade in self.trades
            if trade.get('market_direction') == regime
        ]
        if outcomes:
            report[f'{regime}_win_rate'] = sum(outcomes) / len(outcomes) * 100

    return report
|
|
|
|
def initialize_price_predictor(self, device="cpu"):
    """Create the price prediction model, its Adam optimizer and an empty forecast.

    Args:
        device: Target passed to the model's `.to()` (defaults to "cpu").
    """
    predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
    self.price_predictor = predictor
    predictor.to(device)

    self.price_predictor_optimizer = optim.Adam(predictor.parameters(), lr=1e-3)
    # No forecast exists until update_price_predictions() has run
    self.predicted_prices = np.array([])
|
|
|
|
def train_price_predictor(self):
    """Run a short training pass of the price predictor on the price history.

    Returns:
        The loss reported by the predictor's `train_on_new_data`, or 0.0
        when fewer than 35 prices are available.
    """
    history = self.features['price']
    if len(history) < 35:
        return 0.0

    return self.price_predictor.train_on_new_data(
        history,
        self.price_predictor_optimizer,
        epochs=5
    )
|
|
|
|
def update_price_predictions(self):
    """Refresh self.predicted_prices with a five-candle forecast.

    The forecast is reset to an empty array when fewer than 30 prices are
    available, when no predictor has been initialised, or when the
    predictor raises.
    """
    history = self.features['price']
    predictor = getattr(self, 'price_predictor', None)

    if len(history) < 30 or predictor is None:
        self.predicted_prices = np.array([])
        return

    try:
        forecast = predictor.predict_next_candles(history, num_candles=5)
    except Exception as e:
        # Best effort: a failed forecast must not interrupt the caller
        logger.warning(f"Error updating predictions: {e}")
        forecast = np.array([])

    self.predicted_prices = forecast
|
|
|
|
def identify_optimal_trades(self):
    """Mark local price bottoms and tops as optimal buy and sell points.

    Stores the extrema indices in self.optimal_bottoms / self.optimal_tops
    and builds self.optimal_signals, an array aligned with the price
    history holding 1 at bottoms, -1 at tops and 0 elsewhere. Does nothing
    when fewer than 20 prices are available.
    """
    prices = self.features['price']
    if len(prices) < 20:
        return

    bottoms, tops = find_local_extrema(prices, window=5)
    self.optimal_bottoms = bottoms  # Buy points
    self.optimal_tops = tops        # Sell points

    signals = np.zeros(len(prices))
    self.optimal_signals = signals
    # Bottoms are written first, so an index listed in both ends up as a sell
    for marker, positions in ((1, bottoms), (-1, tops)):
        for pos in positions:
            if 0 <= pos < len(signals):  # Ensure index is valid
                signals[pos] = marker

    logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
|
|
|
|
def calculate_position_size(self):
    """Pick a position size from the balance, with random jitter and leverage.

    The base size is the balance times `max_position_size_pct`, scaled by
    a random factor between 0.7 and 1.0. When a `leverage` attribute above
    1 is present the size is multiplied by it and capped at ten times the
    balance.

    Returns:
        float: Position size in quote currency
    """
    size = self.balance * (self.max_position_size_pct * random.uniform(0.7, 1.0))

    # A missing leverage attribute is treated like leverage 1 (no scaling)
    if getattr(self, 'leverage', 1) > 1:
        leveraged = size * self.leverage
        ceiling = self.balance * 10  # Limit max risk
        size = min(leveraged, ceiling)

    return size
|
|
|
|
def calculate_fees(self, position_size, fee_rate=0.001):
    """Calculate trading fees for a given position size.

    Args:
        position_size: Notional size of the position the fee applies to.
        fee_rate: Fee as a fraction of the position size. Defaults to
            0.001 (0.1%), the rate this method always used before the
            parameter existed.

    Returns:
        The fee, in the same unit as `position_size`.
    """
    return position_size * fee_rate
|
|
|
|
def is_uncertain_market(self):
    """Check if the market is in an uncertain/sideways state.

    Returns:
        True when fewer than 20 prices exist; otherwise True when the
        last 20 prices span less than 2% of their mean and the two EMAs
        differ by less than 0.5%. Without EMA data the price span alone
        must be below 1.5%.
    """
    prices = self.features['price']
    if len(prices) < 20:
        return True

    # Relative width of the range covered by the last 20 prices
    window = prices[-20:]
    spread = (max(window) - min(window)) / np.mean(window)

    if len(self.features['ema_9']) == 0 or len(self.features['ema_21']) == 0:
        return spread < 0.015  # Very narrow range

    fast = self.features['ema_9'][-1]
    slow = self.features['ema_21'][-1]
    ema_gap = abs(fast - slow) / slow

    # Price range is small and EMAs are close
    return spread < 0.02 and ema_gap < 0.005
|
|
|
|
def is_near_support(self):
    """Check if current price is near a support level.

    A support level is any price in the last 30 that is strictly lower
    than both of its neighbours.

    Returns:
        bool: True when self.current_price lies within 1% of such a low.
    """
    if not hasattr(self, 'features') or len(self.features['price']) < 30:
        return False

    window = self.features['price'][-30:]
    # Slide a three-price window over the series and keep strict local minima
    troughs = [
        mid
        for before, mid, after in zip(window, window[1:], window[2:])
        if mid < before and mid < after
    ]
    if not troughs:
        return False

    spot = self.current_price
    return any(abs(spot - trough) / trough < 0.01 for trough in troughs)
|
|
|
|
def is_near_resistance(self):
    """Check if current price is near a resistance level.

    A resistance level is any price in the last 30 that is strictly
    higher than both of its neighbours.

    Returns:
        bool: True when self.current_price lies within 1% of such a high.
    """
    if not hasattr(self, 'features') or len(self.features['price']) < 30:
        return False

    window = self.features['price'][-30:]
    # Slide a three-price window over the series and keep strict local maxima
    peaks = [
        mid
        for before, mid, after in zip(window, window[1:], window[2:])
        if mid > before and mid > after
    ]
    if not peaks:
        return False

    spot = self.current_price
    return any(abs(spot - peak) / peak < 0.01 for peak in peaks)
|
|
|
|
def is_market_turning(self):
    """Check if the market is potentially changing direction.

    Two signals are checked, either of which is sufficient:

    * Divergence: over the last five candles price and RSI moved in
      opposite directions (price up with RSI down, or price down with RSI
      up). Previously this test was inverted and fired when price and RSI
      moved the same way.
    * Crossover: the 9-period EMA crossed the 21-period EMA between the
      previous and the current value.

    Returns:
        bool: True when either signal is present; False otherwise or when
        fewer than 20 prices are available.
    """
    prices = self.features['price']
    if len(prices) < 20:
        return False

    # Check for divergence between price and momentum
    rsi = self.features['rsi']
    if len(rsi) > 5:
        price_change = prices[-1] - prices[-5]
        rsi_change = rsi[-1] - rsi[-5]
        bearish_divergence = price_change > 0 and rsi_change < 0
        bullish_divergence = price_change < 0 and rsi_change > 0
        if bearish_divergence or bullish_divergence:
            return True

    # Check for EMA crossover
    fast = self.features['ema_9']
    slow = self.features['ema_21']
    if len(fast) > 1 and len(slow) > 1:
        crossed_up = fast[-2] < slow[-2] and fast[-1] > slow[-1]
        crossed_down = fast[-2] > slow[-2] and fast[-1] < slow[-1]
        if crossed_up or crossed_down:
            return True

    return False
|
|
|
|
def is_market_against_position(self, position_type):
    """Check if market conditions have turned against the given position.

    Args:
        position_type: 'long' or 'short'; anything else yields False.

    Returns:
        For a long: the market is in a downtrend and price is not near
        support. For a short: the market is in an uptrend and price is not
        near resistance.
    """
    # (adverse trend check, protective level check) per position side
    checks = {
        'long': (self.is_downtrend, self.is_near_support),
        'short': (self.is_uptrend, self.is_near_resistance),
    }
    selected = checks.get(position_type)
    if selected is None:
        return False

    trend_is_adverse, level_is_close = selected
    return trend_is_adverse() and not level_is_close()
|
|
|
|
def is_near_optimal_exit(self, position_type):
    """Check if the latest candle is close to an optimal exit for the position.

    Longs exit near marked tops, shorts near marked bottoms; "near" means
    fewer than 3 candles away from the latest price index.

    Args:
        position_type: 'long' or 'short'; anything else yields False.

    Returns:
        bool: True when a matching extremum is within range.
    """
    latest_idx = len(self.features['price']) - 1

    if position_type == 'long' and hasattr(self, 'optimal_tops'):
        exit_points = self.optimal_tops
    elif position_type == 'short' and hasattr(self, 'optimal_bottoms'):
        exit_points = self.optimal_bottoms
    else:
        return False

    return any(abs(latest_idx - point) < 3 for point in exit_points)
|
|
|
|
def calculate_future_profit_potential(self, position_type, lookahead=20):
    """
    Best-case profit of holding a position over the coming candles.

    Looks ahead in the historical data, so it is only meaningful for
    retrospective backtesting rewards.

    Args:
        position_type: 'long' uses the highest future close; any other
            value is treated as short and uses the lowest future close.
        lookahead: Number of candles to look ahead

    Returns:
        Potential profit percentage relative to self.current_price, or 0
        when no future candles are available.
    """
    total = len(self.data)
    step = self.current_step
    if total <= 1 or step >= total:
        return 0

    # Closes of up to `lookahead` candles after the current step
    horizon = min(lookahead + 1, total - step)
    upcoming = [
        self.data[step + offset]['close']
        for offset in range(1, horizon)
        if step + offset < total
    ]
    if not upcoming:
        return 0

    spot = self.current_price
    if position_type == 'long':
        return (max(upcoming) - spot) / spot * 100
    return (spot - min(upcoming)) / spot * 100
|
|
|
|
async def initialize_futures(self, exchange):
    """Initialize futures trading parameters on the exchange.

    Does nothing in demo mode. Otherwise enables hedge position mode,
    cross margin and the configured leverage for self.futures_symbol. If
    any of these calls fails, the instance is switched to demo mode so no
    live orders are sent against a half-configured account.

    Args:
        exchange: Async exchange client exposing set_position_mode,
            set_margin_mode and set_leverage.
    """
    if self.demo:
        return

    try:
        await exchange.set_position_mode(True)  # Hedge mode
        await exchange.set_margin_mode("cross", symbol=self.futures_symbol)
        await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
        logger.info(f"Futures initialized with {self.leverage}x leverage")
    except Exception as e:
        logger.error(f"Failed to initialize futures trading: {str(e)}")
        logger.info("Falling back to demo mode for safety")
        # Must be set on the instance; a bare local `demo` would leave live
        # trading enabled after a failed setup.
        self.demo = True
|
|
|
|
async def execute_real_trade(self, exchange, action, current_price):
    """Execute real futures trade on MEXC.

    Args:
        exchange: Exchange client exposing an awaitable ``create_order``.
        action: 1 opens a long, 2 opens a short, 3 closes the current
            position. Any other value places no order.
        current_price: Not used by this method; orders are sent as market
            orders.

    Returns:
        The order returned by ``exchange.create_order``, or None when no
        order was placed or the exchange call failed (failures are logged).
    """
    if action not in (1, 2, 3):
        # Nothing to execute; returning early avoids reporting a bogus
        # "Trade execution failed" error for non-trading actions.
        return None

    try:
        if action == 3:  # Close position
            if self.position not in ('long', 'short'):
                logger.warning("Close requested but no open position; no order sent")
                return None
            closing_long = self.position == 'long'
            position_side = 'LONG' if closing_long else 'SHORT'
            side = 'sell' if closing_long else 'buy'
            amount = self.position_size
            verb = "Closed"
        else:  # Open long (1) or short (2)
            opening_long = action == 1
            position_side = 'LONG' if opening_long else 'SHORT'
            side = 'buy' if opening_long else 'sell'
            amount = self.calculate_position_size()
            verb = "Opened"

        order = await exchange.create_order(
            symbol=self.futures_symbol,
            type='market',
            side=side,
            amount=amount,
            params={'positionSide': position_side}
        )
        logger.info(f"{verb} {position_side} position: {order}")
        return order
    except Exception as e:
        logger.error(f"Trade execution failed: {e}")
        return None
|
|
|
|
def is_volatile_market(self):
    """Report whether recent price action looks volatile.

    Works on the last 20 entries of ``self.features['price']`` plus the
    matching ``'volume'`` and ``'atr'`` series. The market is flagged when
    any one of these holds: the standard deviation of returns exceeds 0.5%,
    average volume over the last 10 entries is more than 1.8x the 10 before,
    the price range exceeds 1.5%, or the latest ATR is more than 1.5x its
    recent average.

    Returns:
        bool: True if market is volatile, False otherwise (always False
        with fewer than 20 price points).
    """
    prices = self.features['price']
    if len(prices) < 20:
        return False

    # Dispersion of simple returns over the window, in percent.
    window = prices[-20:]
    pct_moves = np.diff(window) / window[:-1]
    volatility = np.std(pct_moves) * 100

    # Latest 10 volume readings versus the 10 before them.
    volumes = self.features['volume']
    latest_volumes = volumes[-10:]
    earlier_avg = np.mean(volumes[-20:-10])
    latest_avg = np.mean(latest_volumes)
    if earlier_avg > 0:
        volume_increase = latest_avg / earlier_avg
    else:
        volume_increase = 1.0

    # ATR check only runs once more than five readings exist.
    atr_high = False
    atr_series = self.features['atr']
    if len(atr_series) > 5:
        latest_atr = atr_series[-1]
        baseline_atr = np.mean(atr_series[-20:-1])
        ratio = latest_atr / baseline_atr if baseline_atr > 0 else 1.0
        atr_high = ratio > 1.5

    # High-to-low spread of the window relative to its low.
    lowest = min(window)
    price_range_percent = (max(window) - lowest) / lowest * 100

    # Any single trigger is enough.
    volatile = (
        volatility > 0.5
        or volume_increase > 1.8
        or price_range_percent > 1.5
        or atr_high
    )

    if volatile:
        logger.info(f"Volatile market detected - Volatility: {volatility:.2f}%, Volume increase: {volume_increase:.2f}x, Price range: {price_range_percent:.2f}%")

    return volatile
|
|
|
|
def adapt_trading_parameters_to_market(self):
    """Re-tune stop loss, take profit and position sizing for the current market.

    Queries ``is_volatile_market()``, ``is_uptrend()`` and ``is_downtrend()``
    and scales ``STOP_LOSS_PERCENT``, ``TAKE_PROFIT_PERCENT`` and a 50% base
    position size accordingly. Volatility takes precedence over trend. Within
    a trend the multipliers depend on whether ``self.position`` is aligned
    with it ('flat' counts as aligned).

    Returns:
        None. Results are stored in ``self.stop_loss_pct``,
        ``self.take_profit_pct`` and ``self.max_position_size_pct``.
    """
    volatile = self.is_volatile_market()
    uptrend = self.is_uptrend()
    downtrend = self.is_downtrend()

    base_size = 0.5  # 50% of max

    # Multipliers as (stop loss, take profit, position size).
    with_trend = (0.9, 1.6, 1.2)
    against_trend = (0.7, 1.2, 0.8)

    if volatile:
        # Tighter stop, higher target, larger size.
        factors = (0.7, 1.8, 1.3)
    elif uptrend:
        factors = with_trend if self.position in ('long', 'flat') else against_trend
    elif downtrend:
        factors = with_trend if self.position in ('short', 'flat') else against_trend
    else:
        # Sideways or uncertain market: base size is left unscaled.
        factors = (0.8, 1.3, 1.0)

    stop_factor, profit_factor, size_factor = factors
    self.stop_loss_pct = STOP_LOSS_PERCENT * stop_factor
    self.take_profit_pct = TAKE_PROFIT_PERCENT * profit_factor
    self.max_position_size_pct = base_size * size_factor

    logger.debug(f"Adapted trading parameters - Stop loss: {self.stop_loss_pct:.2f}%, Take profit: {self.take_profit_pct:.2f}%, Max position size: {self.max_position_size_pct*100:.1f}%")
|
|
|
|
# Ensure GPU usage if available
|
|
def get_device():
    """Return the torch device to run on: the CUDA GPU when present, else the CPU."""
    if not torch.cuda.is_available():
        cpu = torch.device("cpu")
        logger.info("GPU not available, using CPU")
        return cpu

    gpu = torch.device("cuda")
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    # Turn on the cuDNN autotuner for the GPU path.
    torch.backends.cudnn.benchmark = True
    return gpu
|
|
|
|
# Update Agent class to use GPU properly
|
|
class Agent:
    """DQN agent: policy/target networks, replay memory and epsilon-greedy control.

    Relies on ``DQN``, ``ReplayMemory`` and the hyper-parameter constants
    (``LEARNING_RATE``, ``MEMORY_SIZE``, ``EPSILON_*``, ``BATCH_SIZE``,
    ``GAMMA``, ``TARGET_UPDATE``) defined elsewhere in this module.
    """

    def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
        """Initialize Agent with architecture parameters stored as attributes.

        Args:
            state_size: Size of the state passed to the networks.
            action_size: Number of discrete actions.
            hidden_size: Hidden width passed to ``DQN``.
            lstm_layers: LSTM layer count passed to ``DQN``.
            attention_heads: Attention head count passed to ``DQN``.
            device: torch device to use; chosen by ``get_device()`` when None.
        """
        self.state_size = state_size
        self.action_size = action_size
        # Architecture parameters are kept so save() can record them.
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.attention_heads = attention_heads

        self.device = device if device is not None else get_device()

        # The target network starts as an exact copy of the policy network.
        self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
        self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = ReplayMemory(MEMORY_SIZE)

        # Exploration parameters
        self.epsilon = EPSILON_START
        self.epsilon_decay = EPSILON_DECAY
        self.epsilon_min = EPSILON_END

        self.steps_done = 0

        # TensorBoard writer is attached later (train_agent creates one).
        self.writer = None

        # GradScaler is only needed for mixed precision on CUDA.
        self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None

    def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
        """Expand the model to handle more features or increase capacity.

        Builds new policy/target networks, copies every parameter whose shape
        is unchanged, copies the overlapping part of ``fc1.weight``, and
        leaves all other mismatched layers at their fresh initialisation.
        The optimizer is recreated and the stored architecture attributes
        are updated to describe the new networks.

        Args:
            new_state_size: State size of the expanded networks.
            new_hidden_size: Hidden width of the expanded networks.
            new_lstm_layers: LSTM layer count of the expanded networks.
            new_attention_heads: Attention head count of the expanded networks.

        Returns:
            True once the networks have been replaced.
        """
        logger.info(f"Expanding model: {self.state_size} → {new_state_size}, "
                    f"hidden: {getattr(self, 'hidden_size', None)} → {new_hidden_size}")

        old_state_dict = self.policy_net.state_dict()

        new_policy_net = DQN(new_state_size, self.action_size,
                             new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
        new_target_net = DQN(new_state_size, self.action_size,
                             new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)

        # Transfer weights for common layers
        new_state_dict = new_policy_net.state_dict()
        for name, param in old_state_dict.items():
            if name not in new_state_dict:
                continue
            target = new_state_dict[name]
            if target.shape == param.shape:
                new_state_dict[name] = param
            elif name == "fc1.weight":
                # Copy only the overlapping block: both the input width and
                # the hidden width may differ between old and new networks.
                rows = min(target.shape[0], param.shape[0])
                cols = min(target.shape[1], param.shape[1])
                target[:rows, :cols] = param[:rows, :cols]
            else:
                logger.info(f"Layer {name} shapes don't match: {param.shape} vs {target.shape}")

        new_policy_net.load_state_dict(new_state_dict)
        new_target_net.load_state_dict(new_state_dict)

        self.policy_net = new_policy_net
        self.target_net = new_target_net
        self.target_net.eval()

        # The optimizer must track the new parameters.
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)

        # Keep the recorded architecture in sync so save() writes metadata
        # that matches the weights.
        self.state_size = new_state_size
        self.hidden_size = new_hidden_size
        self.lstm_layers = new_lstm_layers
        self.attention_heads = new_attention_heads
        # NOTE(review): transitions already in self.memory were recorded with
        # the previous state size - confirm callers clear or rebuild it.

        total_params = sum(p.numel() for p in self.policy_net.parameters())
        logger.info(f"New model size: {total_params:,} parameters")

        return True

    def select_action(self, state, training=True):
        """Choose an action for *state*.

        In training mode epsilon is recomputed from ``steps_done`` and the
        choice is epsilon-greedy. Outside training the greedy action is
        always taken and the stored epsilon is decayed by 5% per call, never
        going below ``EPSILON_END``.

        Args:
            state: Observation to feed to the policy network.
            training: Whether exploration is enabled.

        Returns:
            Integer action index.
        """
        sample = random.random()

        if training:
            # Exponential decay with a 1.5x factor on the step count.
            self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                np.exp(-1.5 * self.steps_done / EPSILON_DECAY)
            self.steps_done += 1
        else:
            self.epsilon = max(EPSILON_END, self.epsilon * 0.95)

        if sample > self.epsilon or not training:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).to(self.device)
                action_values = self.policy_net(state_tensor)
                return action_values.max(1)[1].item()

        return random.randrange(self.action_size)

    def _q_loss(self, states, actions, rewards, next_states, dones):
        """Return the smooth-L1 loss between predicted and target Q-values."""
        current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))

        # Targets come from the target network and carry no gradient.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))

        # Reshape targets to match the gathered Q-values.
        return F.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))

    def _float_batch(self, items):
        """Convert a list of per-sample values into one float32 tensor on the device."""
        # Building one ndarray first avoids the slow list-of-arrays path.
        return torch.as_tensor(np.asarray(items), dtype=torch.float32).to(self.device)

    def learn(self):
        """Learn from a batch of experiences.

        Returns:
            The loss as a float, or None when the replay memory holds fewer
            than ``BATCH_SIZE`` transitions or an error occurred (logged).
        """
        if len(self.memory) < BATCH_SIZE:
            return None

        try:
            experiences = self.memory.sample(BATCH_SIZE)

            states = self._float_batch([e.state for e in experiences])
            actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
            rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
            next_states = self._float_batch([e.next_state for e in experiences])
            dones = torch.FloatTensor([e.done for e in experiences]).to(self.device)

            if self.device.type == "cuda" and self.scaler is not None:
                # Mixed precision forward pass and scaled backward pass.
                with torch.amp.autocast('cuda'):
                    loss = self._q_loss(states, actions, rewards, next_states, dones)

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradients must be unscaled before clipping.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Standard precision path
                loss = self._q_loss(states, actions, rewards, next_states, dones)

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
                self.optimizer.step()

            # NOTE(review): steps_done is also advanced by select_action()
            # and drives the epsilon schedule there - confirm sharing one
            # counter is intended.
            self.steps_done += 1

            if self.steps_done % TARGET_UPDATE == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            return loss.item()

        except Exception as e:
            logger.error(f"Error during learning: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None

    def update_target_network(self):
        """Copy the policy network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path="models/trading_agent_best_pnl.pt"):
        """Save the model in a format compatible with PyTorch 2.6+.

        Writes network weights, optimizer state, epsilon, the architecture
        parameters and (when present) the GradScaler state. Errors are
        logged and not raised.

        Args:
            path: Destination file. Its parent directory is created if needed.
        """
        try:
            # A bare filename has no directory component to create.
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Fill in architecture parameters that are missing on this instance.
            if not hasattr(self, 'hidden_size'):
                self.hidden_size = 256  # Default value
                logger.warning("Setting default hidden_size=256 for saving")

            if not hasattr(self, 'lstm_layers'):
                self.lstm_layers = 2  # Default value
                logger.warning("Setting default lstm_layers=2 for saving")

            if not hasattr(self, 'attention_heads'):
                self.attention_heads = 4  # Default value
                logger.warning("Setting default attention_heads=4 for saving")

            checkpoint = {
                'policy_net': self.policy_net.state_dict(),
                'target_net': self.target_net.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epsilon': self.epsilon,
                'state_size': self.state_size,
                'action_size': self.action_size,
                'hidden_size': self.hidden_size,
                'lstm_layers': self.lstm_layers,
                'attention_heads': self.attention_heads
            }

            if hasattr(self, 'scaler') and self.scaler is not None:
                checkpoint['scaler'] = self.scaler.state_dict()

            # Save with pickle_protocol=4 for better compatibility
            torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
            logger.info(f"Model saved to {path}")
        except Exception as e:
            logger.error(f"Error saving model: {e}")
            logger.error(traceback.format_exc())

    def _read_checkpoint(self, path):
        """Read the checkpoint at *path*, trying three torch.load strategies in turn."""
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoint files from a trusted source.
        try:
            logger.info(f"Attempting to load model with weights_only=False: {path}")
            checkpoint = torch.load(path, map_location=self.device, weights_only=False)
            logger.info("Model loaded successfully with weights_only=False")
            return checkpoint
        except Exception as e1:
            logger.warning(f"Failed to load with weights_only=False: {e1}")

        try:
            logger.info("Attempting to load with safe_globals context manager")
            from torch.serialization import safe_globals

            # NOTE(review): safe_globals is documented as taking classes or
            # callables; confirm that this string form is accepted.
            with safe_globals(['numpy._core.multiarray.scalar']):
                checkpoint = torch.load(path, map_location=self.device)
            logger.info("Model loaded successfully with safe_globals")
            return checkpoint
        except Exception as e2:
            logger.warning(f"Failed to load with safe_globals: {e2}")

        # Last resort: try with pickle_module=pickle
        logger.info("Attempting to load with pickle_module")
        import pickle
        checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
        logger.info("Model loaded successfully with pickle_module")
        return checkpoint

    def load(self, path="models/trading_agent_best_pnl.pt"):
        """Load a trained model with improved error handling for PyTorch 2.6 compatibility.

        Restores both networks, the optimizer state (best effort), epsilon
        and the architecture attributes recorded in the checkpoint.

        Args:
            path: Checkpoint file previously written by ``save``.

        Raises:
            Exception: Any failure to read or apply the checkpoint is logged
                and re-raised.
        """
        try:
            checkpoint = self._read_checkpoint(path)

            self.policy_net.load_state_dict(checkpoint['policy_net'])
            self.target_net.load_state_dict(checkpoint['target_net'])

            # A mismatched optimizer state is not fatal.
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as e:
                logger.warning(f"Could not load optimizer state: {e}")

            if 'epsilon' in checkpoint:
                self.epsilon = checkpoint['epsilon']

            # Architecture parameters, when the checkpoint recorded them.
            if 'state_size' in checkpoint:
                self.state_size = checkpoint['state_size']
            if 'action_size' in checkpoint:
                self.action_size = checkpoint['action_size']

            if 'hidden_size' in checkpoint:
                self.hidden_size = checkpoint['hidden_size']
            else:
                # Infer it from the first layer of the loaded network.
                try:
                    self.hidden_size = self.policy_net.fc1.weight.shape[0]
                    logger.info(f"Inferred hidden_size={self.hidden_size} from model")
                except Exception:
                    self.hidden_size = 256  # Default value
                    logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")

            self.lstm_layers = checkpoint.get('lstm_layers', 2) if 'lstm_layers' in checkpoint else 2
            self.attention_heads = checkpoint['attention_heads'] if 'attention_heads' in checkpoint else 4

            logger.info(f"Model loaded successfully from {path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            logger.error(traceback.format_exc())
            raise

    def add_chart_to_tensorboard(self, env, global_step):
        """Add trading chart to TensorBoard.

        Does nothing when no writer is attached or when ``env.data`` holds
        fewer than 10 entries. Errors are logged and never raised, so
        training continues without the visualization.

        Args:
            env: Environment providing ``data``, ``trade_signals`` and the
                position attributes shown in the text panel.
            global_step: Step index under which the image and text are logged.
        """
        try:
            # Without a writer there is nowhere to log; skip the rendering.
            if self.writer is None or len(env.data) < 10:
                return

            chart_img = create_candlestick_figure(
                env.data,
                env.trade_signals,
                window_size=100,
                title=f"Trading Chart - Step {global_step}"
            )

            if chart_img is not None:
                # Convert the image to an array and move channels first,
                # the [C, H, W] layout passed to add_image.
                chart_array = np.array(chart_img)
                chart_array = np.transpose(chart_array, (2, 0, 1))
                self.writer.add_image('Trading Chart', chart_array, global_step)

                # Add position information as text
                entry_price = env.entry_price if env.entry_price else 0.00
                position_info = f"""
                **Current Position**: {env.position.upper()}
                **Entry Price**: ${entry_price:.2f}
                **Current Price**: ${env.data[-1]['close']:.2f}
                **Position Size**: ${env.position_size:.2f}
                **Unrealized PnL**: ${env.total_pnl:.2f}
                """
                self.writer.add_text('Position', position_info, global_step)
        except Exception as e:
            logger.error(f"Error adding chart to TensorBoard: {str(e)}")
            # Continue without visualization rather than crashing
|
|
|
|
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
    """Get live price data using websockets.

    Connects to the MEXC websocket, subscribes to the kline channel for
    *symbol* / *timeframe* and yields one candle dict per message that
    carries a ``'data'`` field.

    Args:
        symbol: Market symbol; the slash is removed and the result
            lower-cased to form the channel name.
        timeframe: Kline interval appended to the channel name.

    Yields:
        dict with ``timestamp``, ``open``, ``high``, ``low``, ``close`` and
        ``volume`` (prices and volume as floats).
    """
    uri = "wss://stream.mexc.com/ws"
    channel = f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"

    async with websockets.connect(uri) as websocket:
        # Subscribe to kline data
        await websocket.send(json.dumps({
            "method": "SUBSCRIPTION",
            "params": [channel]
        }))

        logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")

        while True:
            try:
                payload = json.loads(await websocket.recv())

                # Messages without a 'data' field are skipped.
                if 'data' not in payload:
                    continue

                kline = payload['data']
                yield {
                    'timestamp': kline['t'],
                    'open': float(kline['o']),
                    'high': float(kline['h']),
                    'low': float(kline['l']),
                    'close': float(kline['c']),
                    'volume': float(kline['v'])
                }

            except Exception as e:
                logger.error(f"Websocket error: {e}")
                # Pause, then end the stream. No reconnect happens here; the
                # caller has to call this generator again.
                await asyncio.sleep(5)
                break
|
|
|
|
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
    """Train the agent using historical and live data with GPU acceleration.

    Runs ``num_episodes`` episodes of at most ``max_steps_per_episode`` steps.
    Each step selects an action, steps the environment, stores the
    transition and learns from replay memory. After every episode it records
    statistics, logs to TensorBoard and saves best/periodic checkpoints. An
    exception inside an episode is logged and the episode is skipped.

    Args:
        agent: Agent providing ``select_action``, ``learn``, ``memory``,
            ``save``, ``update_target_network`` and ``add_chart_to_tensorboard``.
        env: Trading environment providing ``reset`` and ``step`` plus the
            bookkeeping attributes read below.
        num_episodes: Number of episodes to run.
        max_steps_per_episode: Step cap per episode.

    Returns:
        dict mapping each statistic name to its per-episode list.
    """
    stat_names = (
        'episode_rewards',
        'episode_lengths',
        'balances',
        'win_rates',
        'episode_pnls',
        'cumulative_pnl',
        'drawdowns',
        'prediction_accuracy',
        'trade_analysis',
    )
    stats = {name: [] for name in stat_names}

    # Best values seen so far, used to decide when to save a "best" model.
    best_reward = float('-inf')
    best_pnl = float('-inf')

    # Attach a TensorBoard writer unless the agent already has one.
    if getattr(agent, 'writer', None) is None:
        agent.writer = SummaryWriter('runs/training')

    for episode in range(num_episodes):
        try:
            state = env.reset()
            episode_reward = 0
            prediction_loss = 0

            for step in range(max_steps_per_episode):
                action = agent.select_action(state)

                try:
                    next_state, reward, done, info = env.step(action)
                except Exception as e:
                    logger.error(f"Error in step function: {e}")
                    break

                # Record the transition, then advance.
                agent.memory.push(state, action, reward, next_state, done)
                state = next_state
                episode_reward += reward

                if len(agent.memory) > BATCH_SIZE:
                    agent.learn()

                every_fiftieth = step % 50 == 0

                # Refresh price predictions periodically.
                if every_fiftieth:
                    try:
                        env.update_price_predictions()
                        env.identify_optimal_trades()
                    except Exception as e:
                        logger.warning(f"Error updating predictions: {e}")

                # Chart periodically, on the last allowed step, and on termination.
                if every_fiftieth or (step == max_steps_per_episode - 1) or done:
                    try:
                        global_step = episode * max_steps_per_episode + step
                        agent.add_chart_to_tensorboard(env, global_step)
                    except Exception as e:
                        logger.warning(f"Error adding chart to TensorBoard: {e}")

                if done:
                    break

            # Update target network periodically
            if episode % TARGET_UPDATE == 0:
                agent.update_target_network()

            # Win rate over closed trades, in percent.
            total_trades = env.win_count + env.loss_count
            if total_trades > 0:
                win_rate = env.win_count / total_trades * 100
            else:
                win_rate = 0

            # Train the price predictor every fifth episode when enough data exists.
            try:
                if episode % 5 == 0 and len(env.data) > 50:
                    prediction_loss = env.train_price_predictor()
            except Exception as e:
                logger.warning(f"Error training price predictor: {e}")
                prediction_loss = 0

            # Analyze trades
            try:
                trade_analysis = env.analyze_trades()
                stats['trade_analysis'].append(trade_analysis)
            except Exception as e:
                logger.warning(f"Error analyzing trades: {e}")
                trade_analysis = {}
                stats['trade_analysis'].append({})

            # Prediction accuracy: 100 * (1 - mean relative error) against
            # the last five closes.
            prediction_accuracy = 0.0
            try:
                has_predictions = hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0
                if has_predictions and len(env.data) > 5:
                    actual_prices = [candle['close'] for candle in env.data[-5:]]
                    predicted = env.predicted_prices[:min(5, len(actual_prices))]
                    errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices)]
                    prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
            except Exception as e:
                logger.warning(f"Error calculating prediction accuracy: {e}")

            # Log statistics
            stats['episode_rewards'].append(episode_reward)
            stats['episode_lengths'].append(step + 1)
            stats['balances'].append(env.balance)
            stats['win_rates'].append(win_rate)
            stats['episode_pnls'].append(env.episode_pnl)
            stats['cumulative_pnl'].append(env.total_pnl)
            stats['drawdowns'].append(env.max_drawdown * 100)
            stats['prediction_accuracy'].append(prediction_accuracy)

            # Log detailed trade analysis
            if trade_analysis:
                logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
                            f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
                            f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")

            # Log to TensorBoard
            scalars = (
                ('Reward/train', episode_reward),
                ('Balance/train', env.balance),
                ('WinRate/train', win_rate),
                ('PnL/episode', env.episode_pnl),
                ('PnL/cumulative', env.total_pnl),
                ('Drawdown/percent', env.max_drawdown * 100),
                ('PredictionLoss', prediction_loss),
                ('PredictionAccuracy', prediction_accuracy),
            )
            for tag, value in scalars:
                agent.writer.add_scalar(tag, value, episode)

            # Add final chart for this episode
            try:
                agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)
            except Exception as e:
                logger.warning(f"Error adding final chart: {e}")

            logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
                        f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
                        f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
                        f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")

            # Save best model by reward
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save("models/trading_agent_best_reward.pt")

            # Save best model by PnL
            if env.episode_pnl > best_pnl:
                best_pnl = env.episode_pnl
                agent.save("models/trading_agent_best_pnl.pt")

            # Save checkpoint
            if episode % 10 == 0:
                agent.save(f"models/trading_agent_episode_{episode}.pt")

        except Exception as e:
            logger.error(f"Error in episode {episode}: {e}")
            continue

    # Save the final model, then plot and export the statistics.
    agent.save("models/trading_agent_final.pt")
    plot_training_results(stats)

    return stats
|
|
|
|
def plot_training_results(stats):
    """Plot detailed training results.

    Draws a 3x2 grid (rewards, balance, win rate, episode PnL, cumulative
    PnL, drawdown) to ``training_results.png`` and writes all statistics to
    ``training_stats.csv``.

    Args:
        stats: dict of per-episode lists as produced by ``train_agent``; it
            must contain the six plotted keys.
    """
    # (stats key, panel title, y-axis label) in grid order.
    panels = (
        ('episode_rewards', 'Episode Rewards', 'Reward'),
        ('balances', 'Account Balance', 'Balance ($)'),
        ('win_rates', 'Win Rate', 'Win Rate (%)'),
        ('episode_pnls', 'Episode PnL', 'PnL ($)'),
        ('cumulative_pnl', 'Cumulative PnL', 'Cumulative PnL ($)'),
        ('drawdowns', 'Maximum Drawdown', 'Drawdown (%)'),
    )

    fig = plt.figure(figsize=(20, 15))
    try:
        for position, (key, title, ylabel) in enumerate(panels, start=1):
            ax = fig.add_subplot(3, 2, position)
            ax.plot(stats[key])
            ax.set_title(title)
            ax.set_xlabel('Episode')
            ax.set_ylabel(ylabel)

        fig.tight_layout()
        fig.savefig('training_results.png')
    finally:
        # pyplot keeps figures alive until closed; release it so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    # Save statistics to CSV
    df = pd.DataFrame(stats)
    df.to_csv('training_stats.csv', index=False)

    logger.info("Training statistics saved to training_stats.csv and training_results.png")
|
|
|
|
def evaluate_agent(agent, env, num_episodes=10):
    """Evaluate the agent on test data.

    Runs each episode until the environment reports done, always asking the
    agent for its non-training action.

    Args:
        agent: Agent providing ``select_action(state, training=False)``.
        env: Environment providing ``reset``, ``step``, ``balance`` and
            ``trades`` (dicts; those with a ``'pnl_percent'`` key are counted).
        num_episodes: Number of evaluation episodes.

    Returns:
        Tuple ``(avg_reward, avg_profit, win_rate)``, where the win rate is
        a percentage of the counted trades (0 when there are none).
    """
    total_reward = 0
    total_profit = 0
    total_trades = 0
    winning_trades = 0

    for _ in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        starting_balance = env.balance

        finished = False
        while not finished:
            # Select action (no exploration)
            action = agent.select_action(state, training=False)
            state, reward, finished, info = env.step(action)
            episode_reward += reward

        total_reward += episode_reward
        total_profit += env.balance - starting_balance

        # Only trades carrying a realised 'pnl_percent' are counted.
        realised = [trade['pnl_percent'] for trade in env.trades if 'pnl_percent' in trade]
        total_trades += len(realised)
        winning_trades += sum(1 for pnl in realised if pnl > 0)

    avg_reward = total_reward / num_episodes
    avg_profit = total_profit / num_episodes
    if total_trades > 0:
        win_rate = winning_trades / total_trades * 100
    else:
        win_rate = 0

    logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
                f"Win Rate={win_rate:.1f}%")

    return avg_reward, avg_profit, win_rate
|
|
|
|
async def test_training():
    """Test the training process with a small number of episodes.

    Builds a demo-mode environment with a 100 balance, fetches 1000 candles
    of ETH/USDT 1m data, runs three episodes capped at 100 steps each, and
    finally tries to save the model.

    Returns:
        True when the run completes, False when any step raises (the error
        is logged). The exchange is closed in both cases.
    """
    logger.info("Starting training tests...")

    # Initialize exchange
    exchange = ccxt.mexc({
        'apiKey': MEXC_API_KEY,
        'secret': MEXC_SECRET_KEY,
        'enableRateLimit': True,
    })

    try:
        env = TradingEnvironment(
            exchange=exchange,
            symbol="ETH/USDT",
            timeframe="1m",
            leverage=MAX_LEVERAGE,
            initial_balance=100,  # Small balance for testing
            demo=True  # Always use demo mode for testing
        )

        # Fetch initial data
        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)

        agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)

        test_episodes = 3
        logger.info(f"Running {test_episodes} test episodes...")

        for episode in range(test_episodes):
            state = env.reset()
            episode_reward = 0

            # At most 100 steps per test episode.
            for step in range(1, 101):
                action = agent.select_action(state)
                next_state, reward, done, info = env.step(action)

                # Store the experience and learn from memory.
                agent.memory.push(state, action, reward, next_state, done)
                agent.learn()

                state = next_state
                episode_reward += reward

                # Print progress
                if step % 10 == 0:
                    logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")

                if done:
                    break

            logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")

        # Test model saving
        try:
            agent.save("models/test_model.pt")
            logger.info("Successfully saved model")
        except Exception as e:
            logger.error(f"Error saving model: {e}")

        logger.info("Training tests completed successfully")
        return True

    except Exception as e:
        logger.error(f"Training test failed: {e}")
        return False

    finally:
        await exchange.close()
|
|
|
|
async def initialize_exchange():
    """Initialize the exchange connection.

    Tries ``ccxt.pro.mexc`` first and falls back to ``ccxt.mexc`` when that
    raises AttributeError or ImportError.

    Returns:
        The constructed exchange client.

    Raises:
        Exception: Any other failure is logged and re-raised.
    """
    try:
        settings = {
            'apiKey': MEXC_API_KEY,
            'secret': MEXC_SECRET_KEY,
            'enableRateLimit': True
        }

        try:
            # Each constructor gets its own copy of the settings.
            exchange = ccxt.pro.mexc(dict(settings))
            logger.info(f"Exchange initialized with async support: {exchange.id}")
        except (AttributeError, ImportError):
            # Fall back to standard CCXT
            exchange = ccxt.mexc(dict(settings))
            logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")

        return exchange
    except Exception as e:
        logger.error(f"Failed to initialize exchange: {e}")
        raise
|
|
|
|
async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
    """Return historical OHLCV candles for ``symbol``.

    Thin wrapper around ``fetch_ohlcv_data`` that adds logging.

    Args:
        exchange: Exchange client handed through to ``fetch_ohlcv_data``.
        symbol: Trading pair symbol.
        timeframe: Candle timeframe string.
        limit: Number of candles requested.

    Returns:
        Whatever ``fetch_ohlcv_data`` returned (a warning is logged when it is
        empty), or ``[]`` if an exception was raised along the way.
    """
    try:
        logger.info(f"Fetching historical data for {symbol}, timeframe {timeframe}, limit {limit}")

        # Delegate the actual download to the shared fetch helper
        candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limit)

        if not candles:
            logger.warning("No historical data received")

        return candles
    except Exception as fetch_error:
        logger.error(f"Failed to fetch historical data: {fetch_error}")
        return []
|
|
|
|
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
    """Run the trading bot in live mode by polling the exchange for candles.

    Each loop iteration fetches the newest candle, appends it to ``env``, lets
    ``agent`` choose an action, optionally sends the real order, advances the
    environment once and records metrics to TensorBoard and a CSV trade log.
    The loop ends only on ``KeyboardInterrupt`` or task cancellation; ordinary
    exceptions are logged and the loop resumes after 30 seconds.

    Args:
        agent: Trading agent. ``state_size``, ``action_size``,
            ``select_action(state, training=False)`` and
            ``add_chart_to_tensorboard(env, step)`` are used; a ``writer`` is
            created when missing.
        env: Trading environment. ``add_data``, ``get_state``, ``step``,
            ``execute_real_trade``, ``initialize_futures`` and the ``balance``,
            ``position``, ``position_size``, ``total_pnl``, ``data`` attributes
            are used.
        exchange: Exchange client used for data fetching and order execution.
        symbol: Trading pair symbol.
        timeframe: Candle timeframe requested from the exchange.
        demo: When true no real orders are sent (paper trading).
        leverage: Leverage figure reported in the log.
            NOTE(review): it is not passed to ``env.initialize_futures`` here;
            confirm the environment applies it elsewhere.

    Returns:
        None.
    """
    logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
    logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")

    # Verify agent is properly initialized: fill in informational defaults.
    try:
        for attr_name, fallback in (('hidden_size', 256), ('lstm_layers', 2), ('attention_heads', 4)):
            if not hasattr(agent, attr_name):
                setattr(agent, attr_name, fallback)
                logger.warning(f"Agent missing {attr_name} attribute, using default: {fallback}")

        logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
    except Exception as e:
        logger.error(f"Error checking agent configuration: {e}")
        # Continue anyway, as these are just informational attributes

    if not demo:
        # Confirm with user before starting live trading
        confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
        if confirmation != "CONFIRM":
            logger.info("Live trading canceled by user")
            return

        # Initialize futures trading; any failure downgrades to paper trading.
        try:
            await env.initialize_futures(exchange)
            logger.info(f"Futures trading initialized with {leverage}x leverage")
        except Exception as e:
            logger.error(f"Failed to initialize futures trading: {str(e)}")
            logger.info("Falling back to demo mode for safety")
            demo = True

    # One timestamp names both the TensorBoard run and the trade log.
    session_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Initialize TensorBoard for monitoring
    if not hasattr(agent, 'writer') or agent.writer is None:
        from torch.utils.tensorboard import SummaryWriter
        agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{session_stamp}')

    # Track performance metrics
    trades_count = 0
    winning_trades = 0
    total_profit = 0
    max_drawdown = 0
    peak_balance = env.balance
    step_counter = 0
    prev_position = 'flat'

    # Create the trade log with its CSV header
    os.makedirs('trade_logs', exist_ok=True)
    trade_log_path = f'trade_logs/trades_{session_stamp}.csv'
    with open(trade_log_path, 'w') as f:
        f.write("timestamp,action,price,position_size,balance,pnl\n")

    logger.info("Entering live trading loop...")

    try:
        while True:
            try:
                # Fetch the newest candle for the *requested* timeframe. The
                # last row is used because OHLCV rows arrive oldest-first.
                candles = await fetch_ohlcv_data(exchange, symbol, timeframe, 1)
                if not candles:
                    logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
                    await asyncio.sleep(5)
                    continue

                # Add new data to environment
                env.add_data(candles[-1])

                # Get current state and make it match the agent's input size
                state = env.get_state()
                expected_size = agent.state_size
                if state.shape[0] != expected_size:
                    logger.warning(f"State size mismatch: got {state.shape[0]}, expected {expected_size}")
                    if state.shape[0] < expected_size:
                        state = np.pad(state, (0, expected_size - state.shape[0]))
                    else:
                        state = state[:expected_size]

                action = agent.select_action(state, training=False)

                # Ensure action is valid
                if action >= agent.action_size:
                    logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
                    action = agent.action_size - 1

                # Log action
                action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
                logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")

                # Execute real trade on exchange
                if not demo:
                    current_price = env.data[-1]['close']
                    trade_result = await env.execute_real_trade(exchange, action, current_price)
                    if not isinstance(trade_result, dict) or not trade_result.get('success', False):
                        error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
                        logger.error(f"Trade execution failed: {error_msg}")
                        # Continue with simulated trade for tracking purposes

                # Advance the environment exactly ONCE. The result is inspected
                # instead of retrying the call, because a second env.step()
                # would apply the same action twice.
                step_result = env.step(action)
                if len(step_result) == 3:
                    logger.warning("Step function returned 3 values instead of 4, creating info dict")
                    info = {
                        'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
                        'price': env.current_price,
                        'balance': env.balance,
                        'position': env.position,
                        'pnl': env.total_pnl
                    }
                else:
                    # Any other arity raises ValueError here, handled below.
                    _, _, _, info = step_result

                # Log trade if position changed
                position_changed = env.position != prev_position
                if position_changed:
                    last_profit = getattr(env, 'last_trade_profit', 0)
                    trades_count += 1
                    if last_profit > 0:
                        winning_trades += 1
                    total_profit += last_profit

                    with open(trade_log_path, 'a') as f:
                        f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{last_profit}\n")

                    logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${last_profit:.2f}")

                # Update performance metrics
                if env.balance > peak_balance:
                    peak_balance = env.balance
                current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
                if current_drawdown > max_drawdown:
                    max_drawdown = current_drawdown

                # Update TensorBoard metrics
                step_counter += 1
                agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
                agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
                agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)

                # Update chart visualization
                if step_counter % 5 == 0 or position_changed:
                    agent.add_chart_to_tensorboard(env, step_counter)

                # Log performance summary
                if trades_count > 0:
                    win_rate = (winning_trades / trades_count) * 100
                    agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)

                    performance_text = "\n".join((
                        "",
                        "**Live Trading Performance**",
                        f"Balance: ${env.balance:.2f}",
                        f"Total PnL: ${env.total_pnl:.2f}",
                        f"Trades: {trades_count}",
                        f"Win Rate: {win_rate:.1f}%",
                        f"Max Drawdown: {max_drawdown*100:.1f}%",
                        "",
                    ))
                    agent.writer.add_text('Performance', performance_text, step_counter)

                prev_position = env.position

                # Wait for next candle
                logger.info(f"Waiting for next candle... (Step {step_counter})")
                await asyncio.sleep(10)  # Check every 10 seconds

            except Exception as e:
                logger.error(f"Error in live trading loop: {str(e)}")
                logger.error(traceback.format_exc())
                logger.info("Continuing after error...")
                await asyncio.sleep(30)  # Wait longer after an error

    except KeyboardInterrupt:
        logger.info("Live trading stopped by user")
    finally:
        # Final performance report. Placed in ``finally`` so it is also
        # emitted when the task is cancelled rather than interrupted.
        if trades_count > 0:
            win_rate = (winning_trades / trades_count) * 100
            logger.info(f"Trading session summary:")
            logger.info(f"Total trades: {trades_count}")
            logger.info(f"Win rate: {win_rate:.1f}%")
            logger.info(f"Final balance: ${env.balance:.2f}")
            logger.info(f"Total profit: ${total_profit:.2f}")
            logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
            logger.info(f"Trade log saved to: {trade_log_path}")
|
|
|
|
async def get_latest_candle(exchange, symbol, timeframe="1m"):
    """Return the most recent candle for ``symbol``, or ``None``.

    Args:
        exchange: Exchange client handed through to ``fetch_ohlcv_data``.
        symbol: Trading pair symbol.
        timeframe: Candle timeframe to request. Defaults to "1m", which is
            the value that used to be hard-coded, so existing two-argument
            callers behave as before.

    Returns:
        The newest candle dict produced by ``fetch_ohlcv_data``, or ``None``
        when nothing was received or an exception was raised (both logged).
    """
    try:
        # Ask for a single candle
        data = await fetch_ohlcv_data(exchange, symbol, timeframe, 1)

        if data:
            # OHLCV rows are ordered oldest-first, so the last row is the
            # newest even if the exchange returns more rows than requested.
            return data[-1]

        logger.warning("No candle data received")
        return None
    except Exception as e:
        logger.error(f"Failed to fetch latest candle: {e}")
        return None
|
|
|
|
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
    """Fetch OHLCV data from either an async or a synchronous CCXT client.

    The fetch method is awaited directly when it is a coroutine function.
    Otherwise it runs in the default executor so a blocking HTTP call does not
    stall the event loop.

    Args:
        exchange: CCXT-style client exposing ``fetchOHLCV`` / ``fetch_ohlcv``.
        symbol: Trading pair symbol.
        timeframe: Candle timeframe string.
        limit: Maximum number of candles to request.

    Returns:
        List of dicts with keys ``timestamp``, ``open``, ``high``, ``low``,
        ``close`` and ``volume``, in the order the exchange returned them.
        ``[]`` when the client cannot fetch OHLCV data or any error occurs
        (the error is logged).
    """
    import inspect  # local import: not part of the file's import block

    try:
        # Check if exchange has fetchOHLCV method
        if not hasattr(exchange, 'fetchOHLCV'):
            logger.error("Exchange does not support OHLCV data fetching")
            return []

        fetch = getattr(exchange, 'fetch_ohlcv', None) or exchange.fetchOHLCV

        if inspect.iscoroutinefunction(fetch):
            # Async client: awaiting is the only way to obtain the rows.
            ohlcv = await fetch(symbol, timeframe, limit=limit)
        else:
            # Blocking client: keep the event loop responsive.
            loop = asyncio.get_running_loop()
            ohlcv = await loop.run_in_executor(
                None,
                lambda: fetch(symbol, timeframe, limit=limit)
            )
            # Covers wrapped async methods that are not detected above.
            if inspect.isawaitable(ohlcv):
                ohlcv = await ohlcv

        # Convert to list of dictionaries
        data = []
        for candle in ohlcv:
            # Only the first six fields are used; extra trailing fields
            # must not abort the whole fetch.
            timestamp, open_price, high, low, close, volume = candle[:6]
            data.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
        return data

    except Exception as e:
        logger.error(f"Failed to fetch OHLCV data: {e}")
        return []
|
|
|
|
async def initialize_websocket_data_stream(symbol="ETH/USDT", timeframe="1m"):
    """Load candle history and open the real-time WebSocket for ``symbol``.

    Args:
        symbol: Trading pair symbol (e.g., "ETH/USDT"). The slash is removed
            before the symbol is handed to ``BinanceWebSocket``.
        timeframe: Candle timeframe. "1m", "5m", "15m" and "1h" are mapped to
            their length in seconds; any other value falls back to 60 seconds
            and a warning is logged.

    Returns:
        Tuple ``(websocket, initial_candles)``: the ``BinanceWebSocket``
        instance after ``connect()`` has been awaited, and a list of candle
        dicts with keys ``timestamp`` (epoch milliseconds), ``open``,
        ``high``, ``low``, ``close`` and ``volume``. The list is empty when no
        history was returned. On any exception ``(None, [])`` is returned
        after the error and traceback have been logged.
    """
    interval_by_timeframe = {"1m": 60, "5m": 300, "15m": 900, "1h": 3600}

    try:
        # Initialize historical data handler to get initial data
        historical_data = BinanceHistoricalData()

        # Convert timeframe to seconds for historical data
        if timeframe not in interval_by_timeframe:
            logger.warning(f"Unsupported timeframe '{timeframe}', defaulting to 1m history")
        interval_seconds = interval_by_timeframe.get(timeframe, 60)

        # Fetch initial historical data.
        # NOTE(review): this looks like a synchronous call inside a coroutine;
        # confirm it is acceptable to block the event loop here.
        initial_data = historical_data.get_historical_candles(
            symbol=symbol,
            interval_seconds=interval_seconds,
            limit=1000  # Get 1000 candles for good history
        )

        # Convert pandas DataFrame to list of dictionaries for our environment
        initial_candles = []
        if not initial_data.empty:
            initial_candles = [
                {
                    'timestamp': int(row['timestamp'].timestamp() * 1000),
                    'open': float(row['open']),
                    'high': float(row['high']),
                    'low': float(row['low']),
                    'close': float(row['close']),
                    'volume': float(row['volume'])
                }
                for _, row in initial_data.iterrows()
            ]
            logger.info(f"Loaded {len(initial_candles)} historical candles")
        else:
            logger.warning("No historical data fetched")

        # Initialize WebSocket for real-time data
        binance_ws = BinanceWebSocket(symbol.replace('/', ''))
        await binance_ws.connect()

        logger.info(f"WebSocket for {symbol} initialized successfully")
        return binance_ws, initial_candles

    except Exception as e:
        logger.error(f"Failed to initialize WebSocket data stream: {e}")
        logger.error(traceback.format_exc())
        return None, []
|
|
|
|
async def process_websocket_ticks(websocket, env, agent=None, demo=True, timeframe="1m"):
    """Process real-time ticks from WebSocket and aggregate them into candles.

    Ticks are grouped into candles aligned to epoch-based buckets of the
    timeframe length. When a tick opens a new bucket, the finished candle is
    added to ``env`` and, if an agent is supplied, one trading decision is
    made. A large price move or volume spike within the last 20 ticks
    triggers an additional decision on a snapshot of the unfinished candle.

    Args:
        websocket: BinanceWebSocket instance. ``running``, ``receive()`` and
            ``close()`` are used.
        env: TradingEnvironment instance. ``add_data``, ``get_state``,
            ``step`` and ``balance`` are used.
        agent: Agent instance (optional, for live trading).
        demo: Whether to run in demo mode. Currently not consulted; every
            action is applied to ``env`` only.
        timeframe: Timeframe for candle aggregation. "1m", "5m", "15m" and
            "1h" are supported; anything else is aggregated as "1m" and a
            warning is logged.

    Returns:
        None. The WebSocket is closed when processing stops for any reason.
    """
    # Candle length per supported timeframe, in seconds.
    timeframe_seconds = {"1m": 60, "5m": 300, "15m": 900, "1h": 3600}
    if timeframe not in timeframe_seconds:
        logger.warning(f"Unsupported timeframe '{timeframe}' for tick aggregation, defaulting to 1m")
    interval_ms = timeframe_seconds.get(timeframe, 60) * 1000

    # Candle aggregation state
    current_candle = None
    current_bucket_ms = None
    step_counter = 0

    # Rolling windows (last 20 ticks) for sudden-movement detection
    last_prices = deque(maxlen=20)
    recent_volumes = deque(maxlen=20)
    price_movement_threshold = 0.5  # 0.5% movement threshold
    volume_spike_threshold = 2.0  # 2x average volume

    def _decide_and_step():
        """Select an action with training=False, apply it to env, return its label."""
        state = env.get_state()
        action = agent.select_action(state, training=False)
        # The result is not unpacked, so 3- and 4-value step() both work.
        env.step(action)
        return "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"

    try:
        logger.info("Starting WebSocket tick processing...")

        while websocket.running:
            # Get the next tick from WebSocket
            tick = await websocket.receive()

            if tick is None:
                # No data received, wait and try again
                await asyncio.sleep(0.1)
                continue

            # Extract data from tick
            timestamp = tick.get('timestamp')
            price = tick.get('price')
            if timestamp is None or price is None:
                logger.warning(f"Invalid tick data received: {tick}")
                continue

            volume = tick.get('volume')
            if volume is None:
                # A missing volume must not abort the stream: any exception
                # raised in this loop ends processing and closes the socket.
                volume = 0.0

            last_prices.append(price)
            recent_volumes.append(volume)

            # Check for significant price movement
            if len(last_prices) >= 5:
                baseline = last_prices[0]
                price_change_pct = abs(price - baseline) / baseline * 100 if baseline else 0.0
                earlier_volumes = list(recent_volumes)[:-1]
                avg_volume = np.mean(earlier_volumes) if earlier_volumes else volume
                volume_ratio = volume / avg_volume if avg_volume > 0 else 1.0

                price_moved = price_change_pct > price_movement_threshold
                volume_spiked = volume_ratio > volume_spike_threshold

                # Log significant movements
                if price_moved:
                    logger.info(f"Significant price movement detected: {price_change_pct:.2f}% change")
                if volume_spiked:
                    logger.info(f"Volume spike detected: {volume_ratio:.2f}x average volume")

                # Force more frequent trading decisions on significant movements
                if (price_moved or volume_spiked) and agent is not None and current_candle is not None:
                    # Snapshot of the unfinished candle with the latest price
                    temp_candle = current_candle.copy()
                    temp_candle['close'] = price

                    # NOTE(review): the snapshot is never removed again and
                    # the finished candle is added later as well; confirm
                    # that env.add_data de-duplicates by timestamp.
                    env.add_data(temp_candle)

                    action_name = _decide_and_step()
                    logger.info(f"Significant movement action: {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}")

            # Start of the candle bucket this tick belongs to (epoch ms)
            bucket_ms = (int(timestamp) // interval_ms) * interval_ms

            if current_bucket_ms is None or bucket_ms > current_bucket_ms:
                # New bucket: close the current candle and start a new one
                if current_candle is not None:
                    env.add_data(current_candle)

                    # Process trading decisions if agent is provided
                    if agent is not None:
                        action_name = _decide_and_step()
                        logger.info(f"Step {step_counter}: Action {action_name}, Price: ${price:.2f}, Balance: ${env.balance:.2f}")
                        step_counter += 1

                # Start a new candle
                current_bucket_ms = bucket_ms
                current_candle = {
                    'timestamp': bucket_ms,
                    'open': price,
                    'high': price,
                    'low': price,
                    'close': price,
                    'volume': volume
                }
                logger.debug(f"Started new candle at {datetime.datetime.fromtimestamp(bucket_ms / 1000)}")
            else:
                # Update the current candle (late ticks are folded in too)
                current_candle['high'] = max(current_candle['high'], price)
                current_candle['low'] = min(current_candle['low'], price)
                current_candle['close'] = price
                current_candle['volume'] += volume

    except asyncio.CancelledError:
        logger.info("WebSocket processing canceled")
    except Exception as e:
        logger.error(f"Error in WebSocket tick processing: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Make sure to close the WebSocket
        if websocket:
            await websocket.close()
            logger.info("WebSocket connection closed")
|
|
|
|
# PyTorch 2.6+ model-loading compatibility helper (invoked at the start of main()).
|
|
def ensure_pytorch_compatibility():
    """Ensure compatibility with PyTorch 2.6+ for model loading.

    Registers NumPy's ``scalar`` reconstruction function with
    ``torch.serialization.add_safe_globals``. ``add_safe_globals`` expects the
    callable objects themselves, not their dotted names, so the function is
    imported and passed as an object.

    Returns:
        None. When PyTorch or NumPy does not provide the required names, a
        warning is logged and nothing is registered.
    """
    try:
        from torch.serialization import add_safe_globals

        try:
            # Location of the function in NumPy 2.x
            from numpy._core.multiarray import scalar as numpy_scalar
        except ImportError:
            # Location of the function in NumPy 1.x
            from numpy.core.multiarray import scalar as numpy_scalar

        # Add numpy scalar to safe globals for PyTorch 2.6+
        add_safe_globals([numpy_scalar])
        logger.info("Added numpy scalar to PyTorch safe globals")
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not configure PyTorch compatibility: {e}")
        logger.warning("This might cause issues with model loading in PyTorch 2.6+")
|
|
|
|
# Command-line entry point; ensure_pytorch_compatibility() is called first.
|
|
async def main():
    """Parse command-line options and run the bot in train, eval or live mode.

    Steps: register PyTorch safe globals, parse arguments, create the exchange
    client and the trading environment, load initial candles, then either
    train a new agent, evaluate a stored model, or start live trading (CCXT
    polling or Binance WebSocket). The exchange connection is closed on exit.

    Returns:
        None. Errors are logged together with their traceback.
    """
    import inspect  # local import: used to close sync and async clients alike

    # Ensure PyTorch compatibility (safe globals for model loading)
    ensure_pytorch_compatibility()

    parser = argparse.ArgumentParser(description='Trading Bot')
    parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
                        help='Operation mode: train, eval, or live')
    parser.add_argument('--episodes', type=int, default=1000,
                        help='Number of episodes for training or evaluation')
    parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
                        help='Run in demo mode (paper trading) if true')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m',
                        help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
    parser.add_argument('--leverage', type=int, default=50,
                        help='Leverage for futures trading')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model file for evaluation or live trading')
    parser.add_argument('--use-websocket', action='store_true',
                        help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')
    parser.add_argument('--dashboard', action='store_true',
                        help='Enable Dash dashboard visualization for real-time trading')

    args = parser.parse_args()

    # Convert string boolean to actual boolean
    demo_mode = args.demo.lower() == 'true'

    # Get device (GPU or CPU)
    device = get_device()

    exchange = None

    try:
        # Initialize exchange
        exchange = await initialize_exchange()

        # Create environment
        env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)

        # Every mode starts from the same initial candle history.
        await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

        if args.mode == 'train':
            # Note: Using STATE_SIZE and action_size=4 for consistency
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            logger.info(f"Starting training for {args.episodes} episodes...")
            await train_agent(agent, env, num_episodes=args.episodes)

        elif args.mode in ('eval', 'live'):
            # Determine model path
            model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return

            # Same architecture as in training so stored weights fit
            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)

            # Safe globals were registered at the top of this function.
            # Repeating that here would make an unrelated import problem
            # look like a model-load failure.
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.error(f"Failed to load model: {e}")

                # Ask user if they want to continue with a new model
                if args.mode == 'live':
                    confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
                    if confirmation.lower() != 'y':
                        logger.info("Live trading canceled by user")
                        return
                    logger.info("Continuing with a new model")
                else:
                    logger.info("Continuing evaluation with a new model")

            if args.mode == 'eval':
                logger.info("Evaluating agent...")
                avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
                logger.info(f"Evaluation results: avg_reward={avg_reward}, avg_profit={avg_profit}, win_rate={win_rate}")

            else:
                logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
                logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")

                if args.use_websocket:
                    logger.info("Using Binance WebSocket for real-time data")
                    await live_trading_with_websocket(
                        agent=agent,
                        env=env,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage,
                        use_dashboard=args.dashboard
                    )
                else:
                    logger.info("Using CCXT for real-time data")
                    await live_trading(
                        agent=agent,
                        env=env,
                        exchange=exchange,
                        symbol=args.symbol,
                        timeframe=args.timeframe,
                        demo=demo_mode,
                        leverage=args.leverage
                    )

    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Clean up exchange connection
        if exchange:
            try:
                closer = getattr(exchange, 'close', None)
                if closer is None and hasattr(exchange, 'client'):
                    closer = getattr(exchange.client, 'close', None)
                if closer is not None:
                    outcome = closer()
                    # Only async clients return something awaitable here.
                    if inspect.isawaitable(outcome):
                        await outcome
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.warning(f"Could not properly close exchange connection: {e}")
|
|
|
|
# Chart rendering utility for TensorBoard visualization.
|
|
def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
    """Create a price/volume chart with trade signals for TensorBoard.

    Args:
        data: Sequence of candle dicts. ``timestamp`` (converted with
            ``unit='ms'``), ``close`` and ``volume`` are read here. Only the
            last ``window_size`` entries are plotted.
        trade_signals: Sequence of signal dicts with ``timestamp``, ``price``
            and ``type``. If the last signal has ``balance`` and ``pnl``,
            they are shown as an annotation. ``None`` is treated as empty.
        window_size: Number of most recent candles to plot.
        title: Title placed above the price panel.

    Returns:
        A ``PIL.Image`` holding the rendered PNG, or ``None`` when fewer than
        10 candles are available or rendering fails (the error is logged).
    """
    if len(data) < 10:
        return None

    # (marker, colour) per exact signal type ...
    exact_markers = {
        'buy': ('^', 'green'),
        'sell': ('v', 'red'),
        'close_long': ('x', 'gold'),
        'close_short': ('x', 'black'),
    }
    # ... and per substring, checked in this order when no exact type matches.
    partial_markers = (
        ('stop_loss', ('X', 'purple')),
        ('take_profit', ('*', 'cyan')),
    )

    fig = None
    try:
        # Create figure
        fig = plt.figure(figsize=(12, 8))

        # Prepare data for plotting
        df = pd.DataFrame(data[-window_size:])
        df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
        df.set_index('date', inplace=True)

        # Price panel on top, volume panel below, sharing the time axis
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        price_ax = fig.add_subplot(gs[0])
        volume_ax = fig.add_subplot(gs[1], sharex=price_ax)

        # NOTE(review): this file's import block binds ``mpf`` to
        # ``matplotlib.pyplot`` rather than mplfinance, so this call is
        # expected to raise and the simple line/bar fallback is what
        # normally renders. Confirm whether mplfinance was intended.
        try:
            mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
        except Exception as e:
            logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
            # Fallback to simple plot
            price_ax.plot(df.index, df['close'], label='Price')
            volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)

        # Add trade signals; one bad signal must not spoil the chart
        for signal in trade_signals or ():
            try:
                when = pd.to_datetime(signal['timestamp'], unit='ms')
                level = signal['price']
                kind = signal['type']

                marker_style = exact_markers.get(kind)
                if marker_style is None:
                    marker_style = next(
                        (style for key, style in partial_markers if key in kind),
                        None,
                    )
                if marker_style is not None:
                    marker, color = marker_style
                    price_ax.plot(when, level, marker, color=color, markersize=10)
            except Exception as e:
                logger.warning(f"Error plotting signal: {e}")
                continue

        # Add balance and PnL annotation
        if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
            balance = trade_signals[-1]['balance']
            pnl = trade_signals[-1]['pnl']
            price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
                              xy=(0.02, 0.95), xycoords='axes fraction',
                              bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

        # Set title and format
        price_ax.set_title(title)
        fig.tight_layout()

        # Convert to image
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return Image.open(buf)

    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return None
    finally:
        # Always release the figure, including on failure. pyplot keeps
        # every open figure alive until it is closed explicitly.
        if fig is not None:
            plt.close(fig)
|
|
|
|
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False):
|
|
"""Run the trading bot in live mode using Binance WebSocket for real-time data
|
|
|
|
Args:
|
|
agent: The trading agent to use for decision making
|
|
env: The trading environment
|
|
symbol: The trading pair symbol (e.g., "ETH/USDT")
|
|
timeframe: The candlestick timeframe (e.g., "1m")
|
|
demo: Whether to run in demo mode (paper trading)
|
|
leverage: The leverage to use for trading
|
|
use_dashboard: Whether to display the real-time dashboard
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
logger.info(f"Starting live trading with WebSocket for {symbol} on {timeframe} timeframe")
|
|
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
|
|
|
|
# If not demo mode, confirm with user before starting live trading
|
|
if not demo:
|
|
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
|
|
if confirmation != "CONFIRM":
|
|
logger.info("Live trading canceled by user")
|
|
return
|
|
|
|
# Initialize TensorBoard for monitoring
|
|
if not hasattr(agent, 'writer') or agent.writer is None:
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')
|
|
|
|
# Initialize Dash dashboard if enabled
|
|
dashboard = None
|
|
if use_dashboard:
|
|
try:
|
|
dashboard = TradingDashboard(symbol)
|
|
dashboard_started = dashboard.start() # Start the dashboard in a separate thread
|
|
if dashboard_started:
|
|
logger.info(f"Trading dashboard enabled at http://localhost:8060")
|
|
else:
|
|
logger.warning("Failed to start trading dashboard, continuing without visualization")
|
|
dashboard = None
|
|
except Exception as e:
|
|
logger.error(f"Error initializing dashboard: {e}")
|
|
logger.error(traceback.format_exc())
|
|
dashboard = None
|
|
|
|
# Track performance metrics
|
|
trades_count = 0
|
|
winning_trades = 0
|
|
total_profit = 0
|
|
max_drawdown = 0
|
|
peak_balance = env.balance
|
|
step_counter = 0
|
|
|
|
# Create directory for trade logs
|
|
os.makedirs('trade_logs', exist_ok=True)
|
|
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
|
|
with open(trade_log_path, 'w') as f:
|
|
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
|
|
|
try:
|
|
# Initialize WebSocket connection and get historical data
|
|
websocket, initial_candles = await initialize_websocket_data_stream(symbol, timeframe)
|
|
|
|
if websocket is None or not initial_candles:
|
|
logger.error("Failed to initialize WebSocket data stream")
|
|
return
|
|
|
|
# Load initial historical data into the environment
|
|
logger.info(f"Loading {len(initial_candles)} initial candles into environment")
|
|
for candle in initial_candles:
|
|
env.add_data(candle)
|
|
|
|
# Reset environment with historical data
|
|
env.reset()
|
|
|
|
# Update dashboard with initial data if enabled
|
|
if dashboard:
|
|
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
|
|
|
|
# Initialize futures trading if not in demo mode
|
|
exchange = None
|
|
if not demo:
|
|
# Import ccxt for exchange initialization
|
|
import ccxt.async_support as ccxt_async
|
|
|
|
# Initialize exchange for order execution
|
|
exchange = await initialize_exchange()
|
|
if exchange:
|
|
try:
|
|
await env.initialize_futures(exchange)
|
|
logger.info(f"Futures trading initialized with {leverage}x leverage")
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize futures trading: {str(e)}")
|
|
logger.info("Falling back to demo mode for safety")
|
|
demo = True
|
|
|
|
# Start WebSocket processing in the background
|
|
websocket_task = asyncio.create_task(
|
|
process_websocket_ticks(websocket, env, agent, demo, timeframe)
|
|
)
|
|
|
|
# Main tracking loop
|
|
prev_position = 'flat'
|
|
while True:
|
|
try:
|
|
# Check if position has changed
|
|
if env.position != prev_position:
|
|
trades_count += 1
|
|
if hasattr(env, 'last_trade_profit') and env.last_trade_profit > 0:
|
|
winning_trades += 1
|
|
if hasattr(env, 'last_trade_profit'):
|
|
total_profit += env.last_trade_profit
|
|
|
|
# Log trade details
|
|
current_time = datetime.datetime.now().isoformat()
|
|
action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE"
|
|
with open(trade_log_path, 'a') as f:
|
|
f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n")
|
|
|
|
logger.info(f"Trade executed: {action_name} at ${env.current_price:.2f}, PnL: ${getattr(env, 'last_trade_profit', 0):.2f}")
|
|
|
|
# Update performance metrics
|
|
if env.balance > peak_balance:
|
|
peak_balance = env.balance
|
|
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
|
|
if current_drawdown > max_drawdown:
|
|
max_drawdown = current_drawdown
|
|
|
|
# Update TensorBoard metrics
|
|
step_counter += 1
|
|
if step_counter % 10 == 0: # Update every 10 steps
|
|
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
|
|
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
|
|
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
|
|
|
|
# Update chart visualization
|
|
if step_counter % 30 == 0 or env.position != prev_position:
|
|
agent.add_chart_to_tensorboard(env, step_counter)
|
|
|
|
# Log performance summary
|
|
if trades_count > 0:
|
|
win_rate = (winning_trades / trades_count) * 100
|
|
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
|
|
|
|
performance_text = f"""
|
|
**Live Trading Performance**
|
|
Balance: ${env.balance:.2f}
|
|
Total PnL: ${env.total_pnl:.2f}
|
|
Trades: {trades_count}
|
|
Win Rate: {win_rate:.1f}%
|
|
Max Drawdown: {max_drawdown*100:.1f}%
|
|
"""
|
|
agent.writer.add_text('Performance', performance_text, step_counter)
|
|
|
|
# Update the dashboard with latest data if enabled
|
|
if dashboard:
|
|
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
|
|
|
|
prev_position = env.position
|
|
|
|
# Sleep for a short time to prevent CPU hogging
|
|
await asyncio.sleep(1)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in live trading monitor loop: {str(e)}")
|
|
logger.error(traceback.format_exc())
|
|
await asyncio.sleep(10) # Wait longer after an error
|
|
|
|
except KeyboardInterrupt:
|
|
logger.info("Live trading stopped by user")
|
|
|
|
# Cancel the WebSocket task
|
|
if 'websocket_task' in locals() and not websocket_task.done():
|
|
websocket_task.cancel()
|
|
try:
|
|
await websocket_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
# Close the exchange connection if it exists
|
|
if exchange:
|
|
await exchange.close()
|
|
|
|
# Final performance report
|
|
if trades_count > 0:
|
|
win_rate = (winning_trades / trades_count) * 100
|
|
logger.info(f"Trading session summary:")
|
|
logger.info(f"Total trades: {trades_count}")
|
|
logger.info(f"Win rate: {win_rate:.1f}%")
|
|
logger.info(f"Final balance: ${env.balance:.2f}")
|
|
logger.info(f"Total profit: ${total_profit:.2f}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Critical error in live trading: {str(e)}")
|
|
logger.error(traceback.format_exc())
|
|
|
|
finally:
|
|
# Make sure to close WebSocket
|
|
if 'websocket' in locals() and websocket:
|
|
await websocket.close()
|
|
|
|
# Close the exchange connection if it exists
|
|
if 'exchange' in locals() and exchange:
|
|
await exchange.close()
|
|
|
|
def ensure_pytorch_compatibility():
    """Allow-list NumPy's scalar reconstructor for PyTorch ``weights_only`` loads.

    Checkpoints that contain NumPy scalars reference
    ``numpy.core.multiarray.scalar`` (``numpy._core.multiarray.scalar`` on
    newer NumPy releases). ``torch.serialization.add_safe_globals`` expects
    the callable objects themselves, so the function is resolved from
    whichever module path is importable and the real object is registered.
    Registering dotted-name strings or ``(name, class)`` tuples does not
    allow-list anything.

    This is best-effort: any failure is logged as a warning and never
    raised. On PyTorch builds without ``add_safe_globals`` only a warning
    is emitted.
    """
    try:
        import importlib

        import torch.serialization

        if not hasattr(torch.serialization, 'add_safe_globals'):
            logger.warning("PyTorch serialization module doesn't have add_safe_globals method")
            return

        # Prefer the private NumPy 2.x location; fall back to the legacy path.
        # Stop at the first hit so the deprecated alias module is not touched
        # when the new one is available.
        scalar_ctor = None
        for module_name in ('numpy._core.multiarray', 'numpy.core.multiarray'):
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                continue
            candidate = getattr(module, 'scalar', None)
            if callable(candidate):
                scalar_ctor = candidate
                break

        if scalar_ctor is None:
            logger.warning("PyTorch compatibility check failed: numpy multiarray.scalar not found")
            return

        torch.serialization.add_safe_globals([scalar_ctor])
        logger.info("PyTorch safe globals registered for compatibility")

    except Exception as e:
        # Compatibility shim only: never let it abort start-up.
        logger.warning(f"PyTorch compatibility check failed: {e}")
|
|
|
|
|
|
class TradingDashboard:
    """Dash dashboard that visualizes trading activity.

    The dashboard runs a Dash server in a daemon thread (see ``start``).
    Callers push the latest state with ``update_data``; the browser polls
    every five seconds through a ``dcc.Interval`` and the callbacks render
    whatever the instance currently holds. Reading the state directly from
    the instance (instead of from components embedded in the layout) is
    what makes already-open browser tabs refresh.

    Attributes set by ``__init__``:
        symbol: Trading pair shown in the titles.
        env: Trading environment last passed to ``update_data`` (or None).
        candles: Candle records last passed to ``update_data``. The chart
            reads the ``timestamp`` (epoch milliseconds), ``open``,
            ``high``, ``low``, ``close`` and ``volume`` fields.
        trade_signals: Signal records; the chart reads ``timestamp``
            (epoch milliseconds), ``type`` and ``price``.
        port: Port the server was actually started on, or None before a
            successful ``start``.
        thread: Server thread, or None.
        is_running: True while the server is believed to be running.
    """

    # Number of most recent candles drawn on the chart.
    MAX_CANDLES = 100
    # Number of most recent trades listed in the side panel.
    MAX_RECENT_TRADES = 5
    # How many consecutive ports start() probes, beginning at the requested one.
    PORT_ATTEMPTS = 10
    # (marker color, marker symbol) per signal type; other types use the default.
    _SIGNAL_STYLES = {
        'buy': ('green', 'triangle-up'),
        'sell': ('red', 'triangle-down'),
    }
    _DEFAULT_SIGNAL_STYLE = ('orange', 'circle')

    def __init__(self, symbol="ETH/USDT"):
        """Build the Dash app, its layout and its callbacks.

        Args:
            symbol: Trading pair displayed in the page header and chart title.
        """
        self.symbol = symbol
        self.env = None
        self.candles = []
        self.trade_signals = []
        self.port = None

        # Create Dash app
        self.app = dash.Dash(__name__, suppress_callback_exceptions=True)

        # Create basic layout
        self.app.layout = html.Div([
            # Header
            html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}),

            # Main content
            html.Div([
                # Chart, refreshed by the interval timer
                html.Div([
                    dcc.Graph(id='candlestick-chart', style={'height': '70vh'}),
                    dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0)
                ], style={'width': '70%', 'display': 'inline-block'}),

                # Trading info
                html.Div([
                    html.Div([
                        html.H3("Account Info"),
                        html.Div(id='account-info')
                    ]),
                    html.Div([
                        html.H3("Recent Trades"),
                        html.Div(id='recent-trades')
                    ])
                ], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'})
            ])
        ])

        # Setup callbacks
        self._setup_callbacks()

        # Thread for running the server
        self.thread = None
        self.is_running = False

    def _setup_callbacks(self):
        """Register the interval-driven callbacks on the Dash app."""

        @self.app.callback(
            Output('candlestick-chart', 'figure'),
            [Input('interval-component', 'n_intervals')]
        )
        def update_chart(n):
            return self._build_chart_figure()

        @self.app.callback(
            [Output('account-info', 'children'),
             Output('recent-trades', 'children')],
            [Input('interval-component', 'n_intervals')]
        )
        def update_account_info(n):
            return self._build_account_panels()

    @staticmethod
    def _pnl_color(value):
        """Return the display color for a PnL value (green/red/white)."""
        if value > 0:
            return 'green'
        if value < 0:
            return 'red'
        return 'white'

    def _build_chart_figure(self):
        """Build the price/volume figure from the current candles and signals.

        Returns:
            A two-row plotly figure: candlesticks with trade markers on top,
            volume bars below. With no candles an empty, styled figure is
            returned.
        """
        fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                            vertical_spacing=0.1, row_heights=[0.7, 0.3])

        # Slice first so only the visible window is converted.
        recent_candles = self.candles[-self.MAX_CANDLES:]
        # len() rather than truthiness: the container type is chosen by the caller.
        if len(recent_candles) > 0:
            df = pd.DataFrame(recent_candles)
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')

            # Add candlestick trace
            fig.add_trace(
                go.Candlestick(
                    x=df['timestamp'],
                    open=df['open'],
                    high=df['high'],
                    low=df['low'],
                    close=df['close'],
                    name='Price'
                ),
                row=1, col=1
            )

            # Add volume trace
            fig.add_trace(
                go.Bar(
                    x=df['timestamp'],
                    y=df['volume'],
                    name='Volume'
                ),
                row=2, col=1
            )

            # Only signals at or after the first visible candle are drawn.
            window_start_ms = df['timestamp'].iloc[0].timestamp() * 1000
            # Iterate over a snapshot: the producer may append concurrently.
            for signal in list(self.trade_signals):
                if signal['timestamp'] < window_start_ms:
                    continue

                marker_color, marker_symbol = self._SIGNAL_STYLES.get(
                    signal['type'], self._DEFAULT_SIGNAL_STYLE
                )
                fig.add_trace(
                    go.Scatter(
                        x=[pd.to_datetime(signal['timestamp'], unit='ms')],
                        y=[signal['price']],
                        mode='markers',
                        marker=dict(
                            color=marker_color,
                            size=12,
                            symbol=marker_symbol
                        ),
                        name=signal['type'].capitalize(),
                        showlegend=False
                    ),
                    row=1, col=1
                )

        # Update layout
        fig.update_layout(
            title=f'{self.symbol} Trading Chart',
            xaxis_rangeslider_visible=False,
            template='plotly_dark'
        )

        return fig

    def _build_trade_card(self, trade):
        """Render one trade record (``action``, ``price``, ``pnl``) as a card."""
        return html.Div([
            html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"),
            html.P(f"PnL: ${trade['pnl']:.2f}",
                   style={'color': self._pnl_color(trade['pnl'])})
        ], style={'border': '1px solid #ddd', 'padding': '10px', 'marginBottom': '5px'})

    def _build_account_panels(self):
        """Build the account summary and the recent-trades list.

        Returns:
            ``(account_info, recent_trades)``. Before any environment has
            been supplied both are placeholder strings.
        """
        env = self.env
        if env is None:
            return "No data available", "No trades available"

        # Account info
        account_info = html.Div([
            html.P(f"Balance: ${env.balance:.2f}"),
            html.P(f"PnL: ${env.total_pnl:.2f}",
                   style={'color': self._pnl_color(env.total_pnl)}),
            html.P(f"Position: {env.position.upper()}")
        ])

        # Recent trades, newest first
        trades = getattr(env, 'trades', None)
        if trades:
            recent_trades = [
                self._build_trade_card(trade)
                for trade in reversed(trades[-self.MAX_RECENT_TRADES:])
            ]
        else:
            recent_trades = [html.P("No trades yet")]

        return account_info, recent_trades

    def update_data(self, env=None, candles=None, trade_signals=None):
        """Store the latest state for the next dashboard refresh.

        Each argument left as None keeps the previously stored value. The
        references are stored as given (no copy); the callbacks read them
        on the next interval tick.

        Args:
            env: Trading environment exposing ``balance``, ``total_pnl``,
                ``position`` and optionally ``trades``.
            candles: Candle records for the chart.
            trade_signals: Signal records for the chart markers.
        """
        # Compare against None explicitly so that empty containers and
        # objects with their own truthiness rules are still accepted.
        if env is not None:
            self.env = env

        if candles is not None:
            self.candles = candles

        if trade_signals is not None:
            self.trade_signals = trade_signals

    @staticmethod
    def _find_available_port(host, first_port, attempts=10):
        """Return the first bindable port in ``[first_port, first_port + attempts)``.

        A fresh probe socket is used for every attempt and is closed before
        returning, so the port is free again for the server to bind.

        Returns:
            The port number, or None when every attempt failed.
        """
        for candidate in range(first_port, first_port + attempts):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                try:
                    probe.bind((host, candidate))
                except OSError:
                    logger.warning(f"Port {candidate} is already in use")
                    continue
            return candidate
        return None

    def start(self, host='localhost', port=8060):
        """Start the dashboard server in a separate daemon thread.

        If ``port`` cannot be bound, the following ports are tried as well
        (``PORT_ATTEMPTS`` in total). The port finally used is stored in
        ``self.port``.

        Args:
            host: Interface to bind.
            port: First port to try.

        Returns:
            True when the server thread is still running after a one-second
            grace period; False when the dashboard was already running, no
            port was available, or the server thread stopped early.
        """
        if self.is_running:
            return False

        chosen_port = self._find_available_port(host, port, self.PORT_ATTEMPTS)
        if chosen_port is None:
            logger.error("Could not find an available port for dashboard")
            return False

        self.port = chosen_port
        # Set the flag before the thread starts so that a fast failure in
        # _run_server (which clears it) is not overwritten afterwards.
        self.is_running = True

        self.thread = Thread(target=self._run_server, args=(host, chosen_port))
        self.thread.daemon = True  # This ensures the thread will exit when the main program does
        self.thread.start()

        # Wait a short time to let the server initialize, then verify it
        # did not die during start-up.
        time.sleep(1.0)
        if not self.thread.is_alive() or not self.is_running:
            logger.error("Dashboard thread failed to start")
            self.is_running = False
            return False

        logger.info(f"Trading dashboard started at http://{host}:{chosen_port}")
        return True

    def _run_server(self, host, port):
        """Run the Dash server; blocks until the server stops.

        Errors are logged rather than raised because this runs in a
        background thread. ``is_running`` is cleared whenever the server
        stops, for whatever reason.
        """
        try:
            logger.info(f"Starting Dash server on {host}:{port}")
            # Prefer app.run(); run_server() is the legacy entry point and is
            # only used when run() is not available on the installed Dash.
            serve = getattr(self.app, 'run', None) or self.app.run_server
            serve(debug=False, host=host, port=port, use_reloader=False, threaded=True)
        except Exception as e:
            logger.error(f"Error running dashboard server: {e}")
        finally:
            self.is_running = False
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the bot; log it instead of
        # printing a traceback.
        logger.info("Program terminated by user")