new bot 2
This commit is contained in:
parent
c8b0f77d32
commit
9b6d3f94ed
111
Niki/GPT/pine/tsla-1m-winning.pine
Normal file
@ -0,0 +1,111 @@
|
||||
//@version=6
|
||||
strategy("Aggressive Bear Market Short Strategy - V6 (Aggressive Conditions)", overlay=true, initial_capital=10000, currency=currency.USD, default_qty_type=strategy.percent_of_equity, default_qty_value=2) // Reduced position size
|
||||
|
||||
// === INPUTS ===
|
||||
// Trend Confirmation: Simple Moving Average
|
||||
smaPeriod = input.int(title="SMA Period", defval=50, minval=1)
|
||||
|
||||
// RSI Parameters
|
||||
rsiPeriod = input.int(title="RSI Period", defval=14, minval=1)
|
||||
rsiAggThreshold = input.int(title="Aggressive RSI Threshold", defval=50, minval=1, maxval=100)
|
||||
|
||||
// MACD Parameters
|
||||
macdFast = input.int(title="MACD Fast Length", defval=12, minval=1)
|
||||
macdSlow = input.int(title="MACD Slow Length", defval=26, minval=1)
|
||||
macdSignalL = input.int(title="MACD Signal Length", defval=9, minval=1)
|
||||
|
||||
// Bollinger Bands Parameters
|
||||
bbLength = input.int(title="Bollinger Bands Length", defval=20, minval=1)
|
||||
bbStdDev = input.float(title="BB StdDev Multiplier", defval=2.0, step=0.1)
|
||||
|
||||
// Stochastic Oscillator Parameters
|
||||
stochLength = input.int(title="Stochastic %K Length", defval=14, minval=1)
|
||||
stochSmooth = input.int(title="Stochastic %D Smoothing", defval=3, minval=1)
|
||||
stochAggThreshold = input.int(title="Aggressive Stochastic Threshold", defval=70, minval=1, maxval=100)
|
||||
|
||||
// ADX Parameters
|
||||
adxPeriod = input.int(title="ADX Period", defval=14, minval=1)
|
||||
adxAggThreshold = input.float(title="Aggressive ADX Threshold", defval=20.0, step=0.1)
|
||||
|
||||
// Risk Management
|
||||
stopLossPercent = input.float(title="Stop Loss (%)", defval=0.5, step=0.1)
|
||||
takeProfitPercent = input.float(title="Take Profit (%)", defval=0.3, step=0.1)
|
||||
trailingStopPercent = input.float(title="Trailing Stop (%)", defval=0.3, step=0.1)
|
||||
|
||||
// === INDICATOR CALCULATIONS ===
|
||||
|
||||
// 1. SMA for overall trend determination.
|
||||
smaValue = ta.sma(close, smaPeriod)
|
||||
|
||||
// 2. RSI calculation.
|
||||
rsiValue = ta.rsi(close, rsiPeriod)
|
||||
|
||||
// 3. MACD calculation.
|
||||
[macdLine, signalLine, _] = ta.macd(close, macdFast, macdSlow, macdSignalL)
|
||||
|
||||
// 4. Bollinger Bands calculation.
|
||||
bbBasis = ta.sma(close, bbLength)
|
||||
bbDev = bbStdDev * ta.stdev(close, bbLength)
|
||||
bbUpper = bbBasis + bbDev
|
||||
bbLower = bbBasis - bbDev
|
||||
|
||||
// 5. Stochastic Oscillator calculation.
|
||||
k = ta.stoch(close, high, low, stochLength)
|
||||
d = ta.sma(k, stochSmooth)
|
||||
|
||||
// 6. ADX calculation.
|
||||
[diPlus, diMinus, adxValue] = ta.dmi(adxPeriod, adxPeriod) // Built-in DMI: returns +DI, -DI and ADX
|
||||
|
||||
// === AGGRESSIVE SIGNAL CONDITIONS ===
|
||||
|
||||
// Mandatory Bearish Condition: Price must be below the SMA.
|
||||
bearTrend = close < smaValue
|
||||
|
||||
// Aggressive MACD Condition
|
||||
macdSignalFlag = macdLine < signalLine
|
||||
|
||||
// Aggressive RSI Condition
|
||||
rsiSignalFlag = rsiValue > rsiAggThreshold
|
||||
|
||||
// Aggressive Bollinger Bands Condition
|
||||
bbSignalFlag = close > bbUpper
|
||||
|
||||
// Aggressive Stochastic Condition
|
||||
stochSignalFlag = ta.crossunder(k, stochAggThreshold)
|
||||
|
||||
// Aggressive ADX Condition
|
||||
adxSignalFlag = adxValue > adxAggThreshold
|
||||
|
||||
// Count the number of indicator signals that are true (Weighted).
|
||||
signalWeight = 0.0
|
||||
if macdSignalFlag
|
||||
signalWeight := signalWeight + 0.25
|
||||
if rsiSignalFlag
|
||||
signalWeight := signalWeight + 0.15
|
||||
if bbSignalFlag
|
||||
signalWeight := signalWeight + 0.2
|
||||
if stochSignalFlag
|
||||
signalWeight := signalWeight + 0.15
|
||||
if adxSignalFlag
|
||||
signalWeight := signalWeight + 0.25
|
||||
|
||||
// Take a short position if the bear market condition is met and the signal weight is high enough.
|
||||
if bearTrend and (signalWeight >= 0.5)
|
||||
strategy.entry("Short", strategy.short)
|
||||
|
||||
// === EXIT CONDITIONS ===
|
||||
// Dynamic Trailing Stop Loss
|
||||
if strategy.position_size < 0
|
||||
strategy.exit("Exit Short", from_entry = "Short", stop = math.max(strategy.position_avg_price * (1 + stopLossPercent / 100), high - high * trailingStopPercent / 100), limit= strategy.position_avg_price * (1 - takeProfitPercent / 100))
|
||||
|
||||
|
||||
// === PLOTTING ===
|
||||
plot(smaValue, color=color.orange, title="SMA")
|
||||
plot(bbUpper, color=color.blue, title="Bollinger Upper Band")
|
||||
plot(bbBasis, color=color.gray, title="Bollinger Basis")
|
||||
plot(bbLower, color=color.blue, title="Bollinger Lower Band")
|
||||
plot(adxValue, title="ADX", color=color.fuchsia)
|
||||
|
||||
// Optional: Plot RSI and a horizontal line at the aggressive RSI threshold.
|
||||
plot(rsiValue, title="RSI", color=color.purple)
|
||||
hline(rsiAggThreshold, title="Aggressive RSI Threshold", color=color.red)
|
2
crypto/gogo/_prompts.md
Normal file
@ -0,0 +1,2 @@
|
||||
Let's extend that to have 32 more values - they will be added later, but we need our model architecture to support them.
|
||||
We'd also want to have 5 different timeframes at once: 1s (ticks - probably only price and EMAs), 1m, 15m, 1h and 1d. Each module will accept all the data, but will produce a prediction only for its own timeframe.
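
A minimal PyTorch sketch of one way such a layout could look. Everything here (`MultiTimeframeModel`, the per-timeframe heads, the 43-feature input of the current 11 values plus the 32 to be added) is an assumption for illustration, not code from this commit:

```python
# Hypothetical sketch: shared encoder over all incoming data, one prediction head per timeframe.
import torch
import torch.nn as nn

TIMEFRAMES = ["1s", "1m", "15m", "1h", "1d"]  # per the note above

class MultiTimeframeModel(nn.Module):
    def __init__(self, feature_dim=11 + 32, d_model=256):
        # feature_dim: assumed to be the current 11 inputs plus the 32 values to be added later
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(feature_dim, d_model), nn.ReLU())
        # Every head sees the shared encoding of all the data,
        # but produces a prediction only for its own timeframe.
        self.heads = nn.ModuleDict({tf: nn.Linear(d_model, 5) for tf in TIMEFRAMES})  # e.g. OHLCV

    def forward(self, x):
        h = self.encoder(x)  # x: [batch, feature_dim]
        return {tf: head(h) for tf, head in self.heads.items()}

if __name__ == "__main__":
    preds = MultiTimeframeModel()(torch.randn(4, 43))
    print({tf: p.shape for tf, p in preds.items()})  # each head: torch.Size([4, 5])
```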
|
@ -9,6 +9,10 @@ import ccxt.async_support as ccxt
|
||||
from dotenv import load_dotenv
|
||||
import platform
|
||||
|
||||
# Set Windows event loop policy at module level
|
||||
if platform.system() == 'Windows':
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
class LiveDataManager:
|
||||
def __init__(self, symbol, exchange_name='mexc', window_size=120):
|
||||
load_dotenv() # Load environment variables
|
||||
@ -45,12 +49,16 @@ class LiveDataManager:
|
||||
retries = 3
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
|
||||
candles = await self.exchange.fetch_ohlcv(
|
||||
self.symbol, '1m', since=since, limit=self.window_size
|
||||
)
|
||||
for candle in candles:
|
||||
self.candles.append(self._format_candle(candle))
|
||||
if candles:
|
||||
self.last_candle_time = candles[-1][0]
|
||||
print(f"Fetched {len(candles)} initial candles.")
|
||||
print(f"""Fetched {len(candles)} initial candles for period {since} to {now}.
|
||||
Price range: {min(candle[1] for candle in candles)} to {max(candle[2] for candle in candles)}.
|
||||
Current price: {candles[-1][4]}. Total volume: {sum(candle[5] for candle in candles)}""")
|
||||
return # Exit the function if successful
|
||||
except Exception as e:
|
||||
print(f"Attempt {attempt + 1} failed: {e}")
|
||||
|
@ -133,6 +133,7 @@ class Transformer(nn.Module):
|
||||
def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, dropout=0.1):
|
||||
super(Transformer, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.candle_embedding = nn.Linear(input_dim, d_model)
|
||||
self.tick_embedding = nn.Linear(2, d_model) # Each tick has price and quantity
|
||||
|
||||
@ -152,10 +153,11 @@ class Transformer(nn.Module):
|
||||
self.future_ticks_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
|
||||
self.future_ticks_projection = nn.Linear(d_model, 60) # 30 ticks * (price, quantity) = 60
|
||||
|
||||
def forward(self, candle_data, tick_data, future_candle_mask, future_ticks_mask):
|
||||
# candle_data: [batch_size, seq_len, input_dim]
|
||||
# tick_data: [batch_size, tick_seq_len, 2]
|
||||
|
||||
def forward(self, candle_data, tick_data, future_candle_mask=None, future_ticks_mask=None):
|
||||
# Print shapes for debugging
|
||||
# print(f"Candle data shape: {candle_data.shape}, Expected input dim: {self.input_dim}")
|
||||
|
||||
# Embed candle data
|
||||
candle_embedded = self.candle_embedding(candle_data)
|
||||
candle_embedded = self.positional_encoding(candle_embedded) # Add positional info
|
||||
|
||||
@ -189,7 +191,7 @@ class Transformer(nn.Module):
|
||||
|
||||
# Example instantiation (adjust parameters for ~1B parameters)
|
||||
if __name__ == '__main__':
|
||||
input_dim = 6 + len([5, 10, 20, 60, 120, 200]) # OHLCV + EMAs
|
||||
input_dim = 11 # Changed from 12 to 11 to match your data
|
||||
d_model = 512 # Hidden dimension
|
||||
num_heads = 8
|
||||
num_layers = 6 # Number of encoder/decoder layers
|
||||
@ -220,3 +222,22 @@ if __name__ == '__main__':
|
||||
print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5]
|
||||
print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1]
|
||||
print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2]
|
||||
|
||||
# Make sure to use this when instantiating the model
|
||||
def create_model(input_dim=11):
|
||||
d_model = 512 # Hidden dimension
|
||||
num_heads = 8
|
||||
num_layers = 6 # Number of encoder/decoder layers
|
||||
d_ff = 2048 # Feedforward dimension
|
||||
dropout = 0.1
|
||||
|
||||
model = Transformer(
|
||||
input_dim=input_dim,
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
d_ff=d_ff,
|
||||
dropout=dropout
|
||||
)
|
||||
|
||||
return model
|
||||
|
41
crypto/gogo2/.vscode/launch.json
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Train Bot",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "train", "--episodes", "1000"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Evaluate Bot",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "eval", "--episodes", "10"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Live Trading (Demo)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "live", "--demo"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Live Trading (Real)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "live"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
953
crypto/gogo2/main.py
Normal file
@ -0,0 +1,953 @@
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import datetime
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from collections import deque, namedtuple
|
||||
from dotenv import load_dotenv
|
||||
import ccxt.async_support as ccxt  # async client so fetch_ohlcv can be awaited in fetch_initial_data
|
||||
import websockets
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger("trading_bot")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
|
||||
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
# Constants
|
||||
INITIAL_BALANCE = 100 # USD
|
||||
MAX_LEVERAGE = 100
|
||||
STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage
|
||||
TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5%
|
||||
MEMORY_SIZE = 100000
|
||||
BATCH_SIZE = 64
|
||||
GAMMA = 0.99 # Discount factor
|
||||
EPSILON_START = 1.0
|
||||
EPSILON_END = 0.05
|
||||
EPSILON_DECAY = 10000
|
||||
STATE_SIZE = 40 # Size of our state representation
|
||||
LEARNING_RATE = 1e-4
|
||||
TARGET_UPDATE = 10 # Update target network every 10 episodes
|
||||
|
||||
# Experience replay tuple
|
||||
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
|
||||
|
||||
class ReplayMemory:
|
||||
def __init__(self, capacity):
|
||||
self.memory = deque(maxlen=capacity)
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
self.memory.append(Experience(state, action, reward, next_state, done))
|
||||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
class DQN(nn.Module):
|
||||
def __init__(self, state_size, action_size):
|
||||
super(DQN, self).__init__()
|
||||
|
||||
# Larger architecture for more complex pattern recognition
|
||||
self.fc1 = nn.Linear(state_size, 256)
|
||||
self.bn1 = nn.BatchNorm1d(256)
|
||||
|
||||
# LSTM layer for sequential data
|
||||
self.lstm = nn.LSTM(256, 256, num_layers=2, batch_first=True)
|
||||
|
||||
# Attention mechanism
|
||||
self.attention = nn.MultiheadAttention(256, 4)
|
||||
|
||||
# Output layers
|
||||
self.fc2 = nn.Linear(256, 128)
|
||||
self.bn2 = nn.BatchNorm1d(128)
|
||||
self.fc3 = nn.Linear(128, 64)
|
||||
self.fc4 = nn.Linear(64, action_size)
|
||||
|
||||
# Dueling DQN architecture
|
||||
self.value_stream = nn.Linear(64, 1)
|
||||
self.advantage_stream = nn.Linear(64, action_size)
|
||||
|
||||
def forward(self, x):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0) # Add batch dimension if needed
|
||||
|
||||
# Initial feature extraction
|
||||
x = F.relu(self.bn1(self.fc1(x)))
|
||||
|
||||
# Process sequential data through LSTM
|
||||
x = x.unsqueeze(0) if x.dim() == 2 else x # Add sequence dimension if needed
|
||||
x, _ = self.lstm(x)
|
||||
x = x.squeeze(0) if x.dim() == 3 else x # Remove sequence dimension if only one item
|
||||
|
||||
# Self-attention
|
||||
x_reshaped = x.unsqueeze(1) if x.dim() == 2 else x
|
||||
attn_output, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
|
||||
x = attn_output.squeeze(1) if attn_output.dim() == 3 else attn_output  # drop the sequence dim so BatchNorm1d sees [batch, features]
|
||||
|
||||
# Final layers
|
||||
x = F.relu(self.bn2(self.fc2(x)))
|
||||
x = F.relu(self.fc3(x))
|
||||
|
||||
# Dueling architecture
|
||||
value = self.value_stream(x)
|
||||
advantages = self.advantage_stream(x)
|
||||
qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))
|
||||
|
||||
return qvals
|
||||
|
||||
class TradingEnvironment:
|
||||
def __init__(self, exchange, symbol="ETH/USDT", timeframe="1m", leverage=MAX_LEVERAGE,
|
||||
initial_balance=INITIAL_BALANCE, window_size=60, is_demo=True):
|
||||
self.exchange = exchange
|
||||
self.symbol = symbol
|
||||
self.timeframe = timeframe
|
||||
self.leverage = leverage
|
||||
self.balance = initial_balance
|
||||
self.initial_balance = initial_balance
|
||||
self.window_size = window_size
|
||||
self.is_demo = is_demo
|
||||
|
||||
self.position = None # 'long', 'short', or None
|
||||
self.entry_price = 0.0
|
||||
self.position_size = 0.0
|
||||
self.stop_loss = 0.0
|
||||
self.take_profit = 0.0
|
||||
|
||||
self.data = deque(maxlen=window_size)
|
||||
self.trades = []
|
||||
self.current_step = 0
|
||||
|
||||
# Action space: 0 = hold, 1 = buy, 2 = sell, 3 = close
|
||||
self.action_space = 4
|
||||
|
||||
self._initialize_features()
|
||||
|
||||
def _initialize_features(self):
|
||||
"""Initialize technical indicators and features"""
|
||||
self.features = {
|
||||
'price': [],
|
||||
'volume': [],
|
||||
'rsi': [],
|
||||
'macd': [],
|
||||
'macd_signal': [],
|
||||
'macd_hist': [],
|
||||
'bollinger_upper': [],
|
||||
'bollinger_lower': [],
|
||||
'bollinger_mid': [],
|
||||
'atr': [],
|
||||
'ema_fast': [],
|
||||
'ema_slow': [],
|
||||
'stoch_k': [],
|
||||
'stoch_d': [],
|
||||
'mom': [] # Momentum
|
||||
}
|
||||
|
||||
async def fetch_initial_data(self):
|
||||
"""Fetch historical data to initialize the environment"""
|
||||
logger.info(f"Fetching initial {self.window_size} candles for {self.symbol}...")
|
||||
try:
|
||||
ohlcv = await self.exchange.fetch_ohlcv(
|
||||
self.symbol,
|
||||
timeframe=self.timeframe,
|
||||
limit=self.window_size
|
||||
)
|
||||
|
||||
for candle in ohlcv:
|
||||
timestamp, open_price, high, low, close, volume = candle
|
||||
self.data.append({
|
||||
'timestamp': timestamp,
|
||||
'open': open_price,
|
||||
'high': high,
|
||||
'low': low,
|
||||
'close': close,
|
||||
'volume': volume
|
||||
})
|
||||
|
||||
self._update_features()
|
||||
logger.info(f"Successfully fetched {len(self.data)} initial candles")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching initial data: {e}")
|
||||
raise
|
||||
|
||||
def _update_features(self):
|
||||
"""Calculate technical indicators from price data"""
|
||||
if len(self.data) < 14: # Need minimum data for indicators
|
||||
return
|
||||
|
||||
df = pd.DataFrame(list(self.data))
|
||||
|
||||
# Basic price and volume
|
||||
self.features['price'] = df['close'].values
|
||||
self.features['volume'] = df['volume'].values
|
||||
|
||||
# EMAs
|
||||
self.features['ema_fast'] = df['close'].ewm(span=12, adjust=False).mean().values
|
||||
self.features['ema_slow'] = df['close'].ewm(span=26, adjust=False).mean().values
|
||||
|
||||
# RSI
|
||||
delta = df['close'].diff()
|
||||
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
|
||||
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
self.features['rsi'] = rsi.fillna(50).values
|
||||
|
||||
# MACD
|
||||
ema12 = df['close'].ewm(span=12, adjust=False).mean()
|
||||
ema26 = df['close'].ewm(span=26, adjust=False).mean()
|
||||
macd = ema12 - ema26
|
||||
macd_signal = macd.ewm(span=9, adjust=False).mean()
|
||||
macd_hist = macd - macd_signal
|
||||
|
||||
self.features['macd'] = macd.values
|
||||
self.features['macd_signal'] = macd_signal.values
|
||||
self.features['macd_hist'] = macd_hist.values
|
||||
|
||||
# Bollinger Bands
|
||||
sma20 = df['close'].rolling(window=20).mean()
|
||||
std20 = df['close'].rolling(window=20).std()
|
||||
upper_band = sma20 + (std20 * 2)
|
||||
lower_band = sma20 - (std20 * 2)
|
||||
|
||||
self.features['bollinger_upper'] = upper_band.fillna(method='bfill').values
|
||||
self.features['bollinger_mid'] = sma20.fillna(method='bfill').values
|
||||
self.features['bollinger_lower'] = lower_band.fillna(method='bfill').values
|
||||
|
||||
# ATR (Average True Range)
|
||||
high_low = df['high'] - df['low']
|
||||
high_close = (df['high'] - df['close'].shift()).abs()
|
||||
low_close = (df['low'] - df['close'].shift()).abs()
|
||||
ranges = pd.concat([high_low, high_close, low_close], axis=1)
|
||||
true_range = ranges.max(axis=1)
|
||||
atr = true_range.rolling(window=14).mean()
|
||||
|
||||
self.features['atr'] = atr.fillna(method='bfill').values
|
||||
|
||||
# Stochastic Oscillator
|
||||
low_min = df['low'].rolling(window=14).min()
|
||||
high_max = df['high'].rolling(window=14).max()
|
||||
|
||||
k = 100 * ((df['close'] - low_min) / (high_max - low_min))
|
||||
d = k.rolling(window=3).mean()
|
||||
|
||||
self.features['stoch_k'] = k.fillna(50).values
|
||||
self.features['stoch_d'] = d.fillna(50).values
|
||||
|
||||
# Momentum
|
||||
self.features['mom'] = df['close'].diff(periods=10).values
|
||||
|
||||
async def _update_with_new_data(self, candle):
|
||||
"""Update environment with new candle data"""
|
||||
self.data.append(candle)
|
||||
self._update_features()
|
||||
self._check_position()
|
||||
|
||||
def _check_position(self):
|
||||
"""Check if stop loss or take profit has been hit"""
|
||||
if self.position is None or len(self.features['price']) == 0:
|
||||
return
|
||||
|
||||
current_price = self.features['price'][-1]
|
||||
|
||||
if self.position == 'long':
|
||||
# Check stop loss
|
||||
if current_price <= self.stop_loss:
|
||||
logger.info(f"STOP LOSS triggered at {current_price} (long position)")
|
||||
self._close_position(current_price, 'stop_loss')
|
||||
|
||||
# Check take profit
|
||||
elif current_price >= self.take_profit:
|
||||
logger.info(f"TAKE PROFIT triggered at {current_price} (long position)")
|
||||
self._close_position(current_price, 'take_profit')
|
||||
|
||||
elif self.position == 'short':
|
||||
# Check stop loss
|
||||
if current_price >= self.stop_loss:
|
||||
logger.info(f"STOP LOSS triggered at {current_price} (short position)")
|
||||
self._close_position(current_price, 'stop_loss')
|
||||
|
||||
# Check take profit
|
||||
elif current_price <= self.take_profit:
|
||||
logger.info(f"TAKE PROFIT triggered at {current_price} (short position)")
|
||||
self._close_position(current_price, 'take_profit')
|
||||
|
||||
def get_state(self):
|
||||
"""Create state representation for the agent"""
|
||||
if len(self.data) < 30 or len(self.features['price']) == 0:
|
||||
# Return zeros if not enough data
|
||||
return np.zeros(STATE_SIZE)
|
||||
|
||||
# Create a normalized state vector with recent price action and indicators
|
||||
|
||||
# Price features (normalize recent prices by the latest price)
|
||||
latest_price = self.features['price'][-1]
|
||||
price_features = self.features['price'][-10:] / latest_price - 1.0
|
||||
|
||||
# Volume features (normalize by max volume)
|
||||
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
||||
vol_features = self.features['volume'][-5:] / max_vol
|
||||
|
||||
# Technical indicators
|
||||
rsi = self.features['rsi'][-3:] / 100.0 # Scale to 0-1
|
||||
|
||||
# MACD (normalize)
|
||||
macd_vals = self.features['macd'][-3:]
|
||||
macd_signal = self.features['macd_signal'][-3:]
|
||||
macd_hist = self.features['macd_hist'][-3:]
|
||||
macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5)
|
||||
macd_norm = macd_vals / macd_scale
|
||||
macd_signal_norm = macd_signal / macd_scale
|
||||
macd_hist_norm = macd_hist / macd_scale
|
||||
|
||||
# Bollinger position (where is price relative to bands)
|
||||
bb_upper = self.features['bollinger_upper'][-3:]
|
||||
bb_lower = self.features['bollinger_lower'][-3:]
|
||||
bb_mid = self.features['bollinger_mid'][-3:]
|
||||
price = self.features['price'][-3:]
|
||||
|
||||
# Calculate position of price within Bollinger Bands (0 to 1)
|
||||
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
|
||||
|
||||
# Position info
|
||||
position_info = np.zeros(5)
|
||||
if self.position == 'long':
|
||||
position_info[0] = 1.0 # Position is long
|
||||
position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL %
|
||||
position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss %
|
||||
position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit %
|
||||
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
||||
elif self.position == 'short':
|
||||
position_info[0] = -1.0 # Position is short
|
||||
position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL %
|
||||
position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss %
|
||||
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
|
||||
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
||||
|
||||
# Combine all features
|
||||
state = np.concatenate([
|
||||
price_features, # 10 values
|
||||
vol_features, # 5 values
|
||||
rsi, # 3 values
|
||||
macd_norm, # 3 values
|
||||
macd_signal_norm, # 3 values
|
||||
macd_hist_norm, # 3 values
|
||||
bb_pos, # 3 values
|
||||
self.features['stoch_k'][-3:] / 100.0, # 3 values
|
||||
self.features['stoch_d'][-3:] / 100.0, # 3 values
|
||||
position_info # 5 values
|
||||
])
|
||||
|
||||
# Replace any NaN values
|
||||
state = np.nan_to_num(state, nan=0.0)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, action):
|
||||
"""Execute trading action and return next state, reward, done"""
|
||||
reward = 0.0
|
||||
done = False
|
||||
|
||||
# Get current price
|
||||
if len(self.features['price']) == 0:
|
||||
return self.get_state(), reward, done
|
||||
|
||||
current_price = self.features['price'][-1]
|
||||
|
||||
# Execute action
|
||||
if action == 0: # Hold
|
||||
pass
|
||||
elif action == 1: # Buy (go long)
|
||||
if self.position is None:
|
||||
reward = self._open_long_position()
|
||||
else:
|
||||
reward = -0.1 # Penalty for invalid action
|
||||
elif action == 2: # Sell (go short)
|
||||
if self.position is None:
|
||||
reward = self._open_short_position()
|
||||
else:
|
||||
reward = -0.1 # Penalty for invalid action
|
||||
elif action == 3: # Close position
|
||||
if self.position is not None:
|
||||
reward = self._close_position(current_price, 'agent_decision')
|
||||
else:
|
||||
reward = -0.1 # Penalty for invalid action
|
||||
|
||||
self.current_step += 1
|
||||
|
||||
# Check if episode should end
|
||||
if self.current_step >= 10000 or self.balance <= 0.1 * self.initial_balance:
|
||||
done = True
|
||||
|
||||
return self.get_state(), reward, done
|
||||
|
||||
def _open_long_position(self):
|
||||
"""Open a long position"""
|
||||
current_price = self.features['price'][-1]
|
||||
|
||||
# Calculate position size (90% of balance for fees)
|
||||
position_value = self.balance * 0.9
|
||||
position_size = position_value * self.leverage / current_price
|
||||
|
||||
# Set stop loss and take profit
|
||||
stop_loss = current_price * (1 - STOP_LOSS_PERCENT / 100)
|
||||
take_profit = current_price * (1 + TAKE_PROFIT_PERCENT / 100)
|
||||
|
||||
if not self.is_demo:
|
||||
try:
|
||||
# Create real order on MEXC
|
||||
order = self.exchange.create_market_buy_order(
|
||||
self.symbol,
|
||||
position_size,
|
||||
params={'leverage': self.leverage}
|
||||
)
|
||||
logger.info(f"Opened long position: {order}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error opening long position: {e}")
|
||||
return -0.2 # Penalty for failed order
|
||||
|
||||
# Update state
|
||||
self.position = 'long'
|
||||
self.entry_price = current_price
|
||||
self.position_size = position_size
|
||||
self.stop_loss = stop_loss
|
||||
self.take_profit = take_profit
|
||||
|
||||
# Record trade
|
||||
self.trades.append({
|
||||
'type': 'long',
|
||||
'entry_time': datetime.datetime.now().isoformat(),
|
||||
'entry_price': current_price,
|
||||
'position_size': position_size,
|
||||
'stop_loss': stop_loss,
|
||||
'take_profit': take_profit,
|
||||
'balance_before': self.balance
|
||||
})
|
||||
|
||||
logger.info(f"OPENED LONG at {current_price} | Stop loss: {stop_loss} | Take profit: {take_profit}")
|
||||
|
||||
return 0.1 # Small reward for taking action
|
||||
|
||||
def _open_short_position(self):
|
||||
"""Open a short position"""
|
||||
current_price = self.features['price'][-1]
|
||||
|
||||
# Calculate position size (90% of balance for fees)
|
||||
position_value = self.balance * 0.9
|
||||
position_size = position_value * self.leverage / current_price
|
||||
|
||||
# Set stop loss and take profit
|
||||
stop_loss = current_price * (1 + STOP_LOSS_PERCENT / 100)
|
||||
take_profit = current_price * (1 - TAKE_PROFIT_PERCENT / 100)
|
||||
|
||||
if not self.is_demo:
|
||||
try:
|
||||
# Create real order on MEXC
|
||||
order = self.exchange.create_market_sell_order(
|
||||
self.symbol,
|
||||
position_size,
|
||||
params={'leverage': self.leverage}
|
||||
)
|
||||
logger.info(f"Opened short position: {order}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error opening short position: {e}")
|
||||
return -0.2 # Penalty for failed order
|
||||
|
||||
# Update state
|
||||
self.position = 'short'
|
||||
self.entry_price = current_price
|
||||
self.position_size = position_size
|
||||
self.stop_loss = stop_loss
|
||||
self.take_profit = take_profit
|
||||
|
||||
# Record trade
|
||||
self.trades.append({
|
||||
'type': 'short',
|
||||
'entry_time': datetime.datetime.now().isoformat(),
|
||||
'entry_price': current_price,
|
||||
'position_size': position_size,
|
||||
'stop_loss': stop_loss,
|
||||
'take_profit': take_profit,
|
||||
'balance_before': self.balance
|
||||
})
|
||||
|
||||
logger.info(f"OPENED SHORT at {current_price} | Stop loss: {stop_loss} | Take profit: {take_profit}")
|
||||
|
||||
return 0.1 # Small reward for taking action
|
||||
|
||||
def _close_position(self, current_price, reason):
|
||||
"""Close the current position"""
|
||||
if self.position is None:
|
||||
return 0.0
|
||||
|
||||
pnl = 0.0
|
||||
|
||||
if self.position == 'long':
|
||||
pnl = (current_price - self.entry_price) / self.entry_price
|
||||
elif self.position == 'short':
|
||||
pnl = (self.entry_price - current_price) / self.entry_price
|
||||
|
||||
# Apply leverage to PnL
|
||||
pnl = pnl * self.leverage
|
||||
|
||||
# Account for 0.1% trading fee
|
||||
position_value = self.position_size * self.entry_price
|
||||
fee = position_value * 0.001
|
||||
pnl_dollar = position_value * pnl - fee
|
||||
|
||||
new_balance = self.balance + pnl_dollar
|
||||
|
||||
if not self.is_demo:
|
||||
try:
|
||||
# Execute real order
|
||||
if self.position == 'long':
|
||||
order = self.exchange.create_market_sell_order(self.symbol, self.position_size)
|
||||
else:
|
||||
order = self.exchange.create_market_buy_order(self.symbol, self.position_size)
|
||||
logger.info(f"Closed position: {order}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing position: {e}")
|
||||
# Still calculate PnL but add penalty
|
||||
pnl -= 0.001
|
||||
|
||||
# Update trade record
|
||||
last_trade = self.trades[-1]
|
||||
last_trade.update({
|
||||
'exit_time': datetime.datetime.now().isoformat(),
|
||||
'exit_price': current_price,
|
||||
'exit_reason': reason,
|
||||
'pnl_percent': pnl * 100,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'balance_after': new_balance
|
||||
})
|
||||
|
||||
logger.info(f"CLOSED {self.position} at {current_price} | PnL: {pnl*100:.2f}% | ${pnl_dollar:.2f}")
|
||||
|
||||
# Reset position
|
||||
self.balance = new_balance
|
||||
self.position = None
|
||||
self.entry_price = 0.0
|
||||
self.position_size = 0.0
|
||||
self.stop_loss = 0.0
|
||||
self.take_profit = 0.0
|
||||
|
||||
# Calculate reward (scaled PnL)
|
||||
reward = pnl * 100 # Scale for better learning signal
|
||||
|
||||
return reward
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment for a new episode"""
|
||||
self.balance = self.initial_balance
|
||||
self.position = None
|
||||
self.entry_price = 0.0
|
||||
self.position_size = 0.0
|
||||
self.stop_loss = 0.0
|
||||
self.take_profit = 0.0
|
||||
self.current_step = 0
|
||||
|
||||
return self.get_state()
|
||||
|
||||
class Agent:
|
||||
def __init__(self, state_size, action_size, device="cuda" if torch.cuda.is_available() else "cpu"):
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.device = device
|
||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||
|
||||
# Q-Networks
|
||||
self.policy_net = DQN(state_size, action_size).to(device)
|
||||
self.target_net = DQN(state_size, action_size).to(device)
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
self.target_net.eval()
|
||||
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
self.epsilon = EPSILON_START
|
||||
self.steps_done = 0
|
||||
|
||||
# TensorBoard logging
|
||||
self.writer = SummaryWriter(log_dir='runs/trading_agent')
|
||||
|
||||
def select_action(self, state, training=True):
|
||||
sample = random.random()
|
||||
|
||||
if training:
|
||||
# Epsilon decay
|
||||
self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
|
||||
np.exp(-1. * self.steps_done / EPSILON_DECAY)
|
||||
self.steps_done += 1
|
||||
|
||||
if sample > self.epsilon or not training:
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).to(self.device)
|
||||
action_values = self.policy_net(state_tensor)
|
||||
return action_values.max(1)[1].item()
|
||||
else:
|
||||
return random.randrange(self.action_size)
|
||||
|
||||
def learn(self):
|
||||
if len(self.memory) < BATCH_SIZE:
|
||||
return None
|
||||
|
||||
experiences = self.memory.sample(BATCH_SIZE)
|
||||
batch = Experience(*zip(*experiences))
|
||||
|
||||
# Convert to tensors
|
||||
state_batch = torch.FloatTensor(batch.state).to(self.device)
|
||||
action_batch = torch.LongTensor(batch.action).unsqueeze(1).to(self.device)
|
||||
reward_batch = torch.FloatTensor(batch.reward).to(self.device)
|
||||
next_state_batch = torch.FloatTensor(batch.next_state).to(self.device)
|
||||
done_batch = torch.FloatTensor(batch.done).to(self.device)
|
||||
|
||||
# Get Q values for chosen actions
|
||||
q_values = self.policy_net(state_batch).gather(1, action_batch)
|
||||
|
||||
# Double DQN: use policy net to select actions, target net to evaluate
|
||||
with torch.no_grad():
|
||||
# Get actions from policy net
|
||||
next_actions = self.policy_net(next_state_batch).max(1)[1].unsqueeze(1)
|
||||
# Evaluate using target net
|
||||
next_q_values = self.target_net(next_state_batch).gather(1, next_actions)
|
||||
next_q_values = next_q_values.squeeze(1)
|
||||
|
||||
# Compute target Q values
|
||||
expected_q_values = reward_batch + (GAMMA * next_q_values * (1 - done_batch))
|
||||
expected_q_values = expected_q_values.unsqueeze(1)
|
||||
|
||||
# Compute loss (Huber loss for stability)
|
||||
loss = F.smooth_l1_loss(q_values, expected_q_values)
|
||||
|
||||
# Optimize the model
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Gradient clipping
|
||||
for param in self.policy_net.parameters():
|
||||
param.grad.data.clamp_(-1, 1)
|
||||
self.optimizer.step()
|
||||
|
||||
return loss.item()
|
||||
|
||||
def update_target_network(self):
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
def save(self, path="models/trading_agent.pt"):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'policy_net': self.policy_net.state_dict(),
|
||||
'target_net': self.target_net.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'steps_done': self.steps_done
|
||||
}, path)
|
||||
logger.info(f"Model saved to {path}")
|
||||
|
||||
def load(self, path="models/trading_agent.pt"):
|
||||
if os.path.isfile(path):
|
||||
checkpoint = torch.load(path)
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
self.epsilon = checkpoint['epsilon']
|
||||
self.steps_done = checkpoint['steps_done']
|
||||
logger.info(f"Model loaded from {path}")
|
||||
return True
|
||||
logger.warning(f"No model found at {path}")
|
||||
return False
|
||||
|
||||
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||
"""Get live price data using websockets"""
|
||||
# Connect to MEXC websocket
|
||||
uri = "wss://stream.mexc.com/ws"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Subscribe to kline data
|
||||
subscribe_msg = {
|
||||
"method": "SUBSCRIPTION",
|
||||
"params": [f"spot@public.kline.v3.api@{symbol.replace('/', '').lower()}@{timeframe}"]
|
||||
}
|
||||
await websocket.send(json.dumps(subscribe_msg))
|
||||
|
||||
logger.info(f"Connected to MEXC websocket, subscribed to {symbol} {timeframe} klines")
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await websocket.recv()
|
||||
data = json.loads(response)
|
||||
|
||||
if 'data' in data:
|
||||
kline = data['data']
|
||||
candle = {
|
||||
'timestamp': kline['t'],
|
||||
'open': float(kline['o']),
|
||||
'high': float(kline['h']),
|
||||
'low': float(kline['l']),
|
||||
'close': float(kline['c']),
|
||||
'volume': float(kline['v'])
|
||||
}
|
||||
yield candle
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Websocket error: {e}")
|
||||
# Try to reconnect
|
||||
await asyncio.sleep(5)
|
||||
break
|
||||
|
||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
|
||||
"""Train the agent using historical and live data"""
|
||||
logger.info("Starting training...")
|
||||
|
||||
stats = {
|
||||
'episode_rewards': [],
|
||||
'episode_lengths': [],
|
||||
'balances': [],
|
||||
'win_rates': []
|
||||
}
|
||||
|
||||
best_reward = -float('inf')
|
||||
|
||||
for episode in range(num_episodes):
|
||||
state = env.reset()
|
||||
episode_reward = 0
|
||||
|
||||
for step in range(max_steps_per_episode):
|
||||
# Select action
|
||||
action = agent.select_action(state)
|
||||
|
||||
# Take action
|
||||
next_state, reward, done = env.step(action)
|
||||
|
||||
# Store experience
|
||||
agent.memory.push(state, action, reward, next_state, done)
|
||||
|
||||
state = next_state
|
||||
episode_reward += reward
|
||||
|
||||
# Learn from experience
|
||||
loss = agent.learn()
|
||||
if loss is not None:
|
||||
agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Update target network
|
||||
if episode % TARGET_UPDATE == 0:
|
||||
agent.update_target_network()
|
||||
|
||||
# Calculate win rate
|
||||
if len(env.trades) > 0:
|
||||
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
|
||||
win_rate = wins / len(env.trades) * 100
|
||||
else:
|
||||
win_rate = 0
|
||||
|
||||
# Log statistics
|
||||
stats['episode_rewards'].append(episode_reward)
|
||||
stats['episode_lengths'].append(step + 1)
|
||||
stats['balances'].append(env.balance)
|
||||
stats['win_rates'].append(win_rate)
|
||||
|
||||
# Log to TensorBoard
|
||||
agent.writer.add_scalar('Reward/train', episode_reward, episode)
|
||||
agent.writer.add_scalar('Balance/train', env.balance, episode)
|
||||
agent.writer.add_scalar('WinRate/train', win_rate, episode)
|
||||
|
||||
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
|
||||
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}")
|
||||
|
||||
# Save best model
|
||||
if episode_reward > best_reward:
|
||||
best_reward = episode_reward
|
||||
agent.save("models/trading_agent_best.pt")
|
||||
|
||||
# Save checkpoint
|
||||
if episode % 10 == 0:
|
||||
agent.save(f"models/trading_agent_episode_{episode}.pt")
|
||||
|
||||
# Save final model
|
||||
agent.save("models/trading_agent_final.pt")
|
||||
|
||||
# Plot training results
|
||||
plot_training_results(stats)
|
||||
|
||||
return stats
|
||||
|
||||
def plot_training_results(stats):
|
||||
"""Plot training statistics"""
|
||||
plt.figure(figsize=(15, 10))
|
||||
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.plot(stats['episode_rewards'])
|
||||
plt.title('Episode Rewards')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Reward')
|
||||
|
||||
plt.subplot(2, 2, 2)
|
||||
plt.plot(stats['balances'])
|
||||
plt.title('Account Balance')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Balance ($)')
|
||||
|
||||
plt.subplot(2, 2, 3)
|
||||
plt.plot(stats['episode_lengths'])
|
||||
plt.title('Episode Length')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Steps')
|
||||
|
||||
plt.subplot(2, 2, 4)
|
||||
plt.plot(stats['win_rates'])
|
||||
plt.title('Win Rate')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Win Rate (%)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('training_results.png')
|
||||
plt.close()
|
||||
|
||||
def evaluate_agent(agent, env, num_episodes=10):
|
||||
"""Evaluate the agent on test data"""
|
||||
total_reward = 0
|
||||
total_profit = 0
|
||||
total_trades = 0
|
||||
winning_trades = 0
|
||||
|
||||
for episode in range(num_episodes):
|
||||
state = env.reset()
|
||||
episode_reward = 0
|
||||
initial_balance = env.balance
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
# Select action (no exploration)
|
||||
action = agent.select_action(state, training=False)
|
||||
next_state, reward, done = env.step(action)
|
||||
|
||||
state = next_state
|
||||
episode_reward += reward
|
||||
|
||||
total_reward += episode_reward
|
||||
total_profit += env.balance - initial_balance
|
||||
|
||||
# Count trades and wins
|
||||
for trade in env.trades:
|
||||
if 'pnl_percent' in trade:
|
||||
total_trades += 1
|
||||
if trade['pnl_percent'] > 0:
|
||||
winning_trades += 1
|
||||
|
||||
# Calculate averages
|
||||
avg_reward = total_reward / num_episodes
|
||||
avg_profit = total_profit / num_episodes
|
||||
win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0
|
||||
|
||||
logger.info(f"Evaluation results: Avg Reward={avg_reward:.2f}, Avg Profit=${avg_profit:.2f}, "
|
||||
f"Win Rate={win_rate:.1f}%")
|
||||
|
||||
return avg_reward, avg_profit, win_rate
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='ETH/USD Trading Bot with RL')
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval', 'live'],
|
||||
help='Operation mode: train, eval, or live')
|
||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training/evaluation')
|
||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trading)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize exchange
|
||||
exchange = ccxt.mexc({
|
||||
'apiKey': MEXC_API_KEY,
|
||||
'secret': MEXC_SECRET_KEY,
|
||||
'enableRateLimit': True,
|
||||
})
|
||||
|
||||
# Create environment
|
||||
env = TradingEnvironment(
|
||||
exchange=exchange,
|
||||
symbol="ETH/USDT",
|
||||
timeframe="1m",
|
||||
leverage=MAX_LEVERAGE,
|
||||
initial_balance=INITIAL_BALANCE,
|
||||
is_demo=args.demo or args.mode != 'live' # Only trade for real in live mode
|
||||
)
|
||||
|
||||
# Fetch initial data
|
||||
await env.fetch_initial_data()
|
||||
|
||||
# Create agent
|
||||
agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)
|
||||
|
||||
# Try to load existing model
|
||||
model_loaded = agent.load()
|
||||
if not model_loaded and args.mode in ['eval', 'live']:
|
||||
logger.warning("No pre-trained model found. Consider training first!")
|
||||
|
||||
if args.mode == 'train':
|
||||
# Training mode
|
||||
logger.info("Starting training mode")
|
||||
await train_agent(agent, env, num_episodes=args.episodes)
|
||||
|
||||
elif args.mode == 'eval':
|
||||
# Evaluation mode
|
||||
logger.info("Starting evaluation mode")
|
||||
eval_reward, eval_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
|
||||
|
||||
elif args.mode == 'live':
|
||||
# Live trading mode
|
||||
logger.info("Starting live trading mode with real-time data")
|
||||
logger.info(f"Demo mode: {args.demo}")
|
||||
|
||||
# Live trading loop
|
||||
async for candle in get_live_prices("ETH/USDT", "1m"):
|
||||
# Update environment with new data
|
||||
await env._update_with_new_data(candle)
|
||||
|
||||
# Only trade if we have enough data
|
||||
if len(env.data) >= env.window_size:
|
||||
# Get current state
|
||||
state = env.get_state()
|
||||
|
||||
# Select action (no exploration in live trading)
|
||||
action = agent.select_action(state, training=False)
|
||||
|
||||
# Convert action number to readable format
|
||||
action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
|
||||
logger.info(f"Price: ${candle['close']:.2f} | Action: {action_names[action]}")
|
||||
|
||||
# Take action
|
||||
_, reward, _ = env.step(action)
|
||||
|
||||
# Print statistics
|
||||
if len(env.trades) > 0:
|
||||
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
|
||||
win_rate = wins / len(env.trades) * 100
|
||||
total_pnl = sum(trade.get('pnl_dollar', 0) for trade in env.trades)
|
||||
logger.info(f"Balance: ${env.balance:.2f} | Trades: {len(env.trades)} | "
|
||||
f"Win Rate: {win_rate:.1f}% | Total PnL: ${total_pnl:.2f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Program terminated by user")
|
156
crypto/gogo2/readme.md
Normal file
@ -0,0 +1,156 @@
|
||||
# Crypto Trading Bot with Reinforcement Learning
|
||||
|
||||
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition.
|
||||
|
||||
## Features
|
||||
|
||||
- Deep Q-Learning with experience replay
|
||||
- LSTM layers for sequential data processing
|
||||
- Multi-head attention mechanism
|
||||
- Dueling DQN architecture
|
||||
- Real-time trading capabilities
|
||||
- TensorBoard integration for monitoring
|
||||
- Comprehensive technical indicators
|
||||
- Demo and live trading modes
|
||||
- Automatic model checkpointing
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.8+
|
||||
- MEXC Exchange API credentials
|
||||
- GPU recommended but not required
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yourusername/crypto-trading-bot.git
|
||||
cd crypto-trading-bot
|
||||
```
|
||||
2. Create a virtual environment:
|
||||
|
||||
```bash
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
```
|
||||
3. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
4. Create a `.env` file in the project root with your MEXC API credentials:
|
||||
|
||||
```bash
|
||||
MEXC_API_KEY=your_api_key
|
||||
MEXC_SECRET_KEY=your_api_secret
|
||||
```
|
||||
## Usage
|
||||
|
||||
The bot can be run in three modes:
|
||||
|
||||
### Training Mode
|
||||
|
||||
```bash
|
||||
python main.py --mode train --episodes 1000
|
||||
```
|
||||
|
||||
### Evaluation Mode
|
||||
|
||||
```bash
|
||||
python main.py --mode eval --episodes 10
|
||||
```
|
||||
|
||||
### Live Trading Mode
|
||||
|
||||
```bash
|
||||
# Demo mode (no real trades)
|
||||
python main.py --mode live --demo
|
||||
# Real trading
|
||||
python main.py --mode live
|
||||
```
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters can be adjusted in `main.py`:
|
||||
|
||||
- `INITIAL_BALANCE`: Starting balance for training/demo
|
||||
- `MAX_LEVERAGE`: Maximum leverage for trades
|
||||
- `STOP_LOSS_PERCENT`: Stop loss percentage
|
||||
- `TAKE_PROFIT_PERCENT`: Take profit percentage
|
||||
- `BATCH_SIZE`: Training batch size
|
||||
- `LEARNING_RATE`: Model learning rate
|
||||
- `STATE_SIZE`: Size of the state representation
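
For reference, the defaults defined near the top of `main.py` in this commit are:

```python
INITIAL_BALANCE = 100      # USD
MAX_LEVERAGE = 100
STOP_LOSS_PERCENT = 0.5    # very tight stop loss (0.5%) due to high leverage
TAKE_PROFIT_PERCENT = 1.5  # take profit at 1.5%
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
STATE_SIZE = 40            # size of the state representation
```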
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The DQN model includes:
|
||||
- Input layer with technical indicators
|
||||
- LSTM layers for temporal pattern recognition
|
||||
- Multi-head attention mechanism
|
||||
- Dueling architecture for better Q-value estimation
|
||||
- Batch normalization for stable training
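
A quick way to sanity-check the network is to instantiate it directly from `main.py` (a sketch that assumes you run it from the project root; `eval()` is needed because the batch-norm layers cannot normalise a single sample in training mode):

```python
import torch
from main import DQN, STATE_SIZE  # STATE_SIZE = 40; the bot uses 4 actions

net = DQN(state_size=STATE_SIZE, action_size=4).eval()
with torch.no_grad():
    q_values = net(torch.randn(1, STATE_SIZE))
print(q_values.shape)  # torch.Size([1, 4]): one Q-value per action (hold, buy, sell, close)
```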
|
||||
|
||||
## Monitoring
|
||||
|
||||
Training progress is logged to TensorBoard and can be monitored with:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir=runs
|
||||
```
|
||||
|
||||
This will show:
|
||||
- Training rewards
|
||||
- Account balance
|
||||
- Win rate
|
||||
- Loss metrics
|
||||
|
||||
## Trading Strategy
|
||||
|
||||
The bot makes decisions based on:
|
||||
- Price action
|
||||
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
||||
- Historical patterns through LSTM
|
||||
- Risk management with stop-loss and take-profit
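
As a concrete example of the risk-management rule, `main.py` brackets every new position with a fixed stop-loss and take-profit around the entry price (the numbers below use this commit's defaults):

```python
entry_price = 2000.0         # example ETH price
STOP_LOSS_PERCENT = 0.5      # default in main.py
TAKE_PROFIT_PERCENT = 1.5    # default in main.py

# Long position; short positions mirror these levels around the entry.
stop_loss = entry_price * (1 - STOP_LOSS_PERCENT / 100)      # 1990.0
take_profit = entry_price * (1 + TAKE_PROFIT_PERCENT / 100)  # 2030.0
print(stop_loss, take_profit)
```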
|
||||
|
||||
## Safety Features
|
||||
|
||||
- Demo mode for safe testing
|
||||
- Automatic stop-loss
|
||||
- Position size limits
|
||||
- Error handling for API calls
|
||||
- Logging of all actions
|
||||
|
||||
## Directory Structure
|
||||
```
├── main.py            # Main bot implementation
├── requirements.txt   # Project dependencies
├── .env               # API credentials
├── models/            # Saved model checkpoints
├── runs/              # TensorBoard logs
└── trading_bot.log    # Activity logs
```
|
||||
|
||||
|
||||
## Warning
|
||||
|
||||
Cryptocurrency trading carries significant risks. This bot is for educational purposes and should not be used with real money without thorough testing and understanding of the risks involved.
|
||||
|
||||
## License
|
||||
|
||||
[MIT License](LICENSE)
|
||||
|
||||
|
||||
|
||||
|
8
crypto/gogo2/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
numpy>=1.21.0
|
||||
pandas>=1.3.0
|
||||
matplotlib>=3.4.0
|
||||
torch>=1.9.0
|
||||
python-dotenv>=0.19.0
|
||||
ccxt>=2.0.0
|
||||
websockets>=10.0
|
||||
tensorboard>=2.6.0
|