fixes, lots of new ideas
This commit is contained in:
parent
d9f1bac11c
commit
ad559d8c61
4
.gitignore
vendored
4
.gitignore
vendored
@ -38,4 +38,6 @@ crypto/gogo2/trading_bot.log
|
|||||||
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
|
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
|
||||||
*trading_agent_continuous_*.pt
|
*trading_agent_continuous_*.pt
|
||||||
*trading_agent_episode_*.pt
|
*trading_agent_episode_*.pt
|
||||||
crypto/gogo2/models/trading_agent_continuous_150.pt
|
crypto/gogo2/models/trading_agent_continuous_*.pt
|
||||||
|
crypto/gogo2/visualizations/training_episode_*.png
|
||||||
|
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
|
||||||
|
67
crypto/gogo2/_model.md
Normal file
67
crypto/gogo2/_model.md
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# Neural Network Architecture Analysis for Trading Bot
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document provides a comprehensive analysis of the neural network architecture used in our trading bot system. The system consists of two main neural network components:
|
||||||
|
|
||||||
|
1. **Price Prediction Model** - Forecasts future price movements and extrema points
|
||||||
|
2. **DQN (Deep Q-Network)** - Makes trading decisions based on state representations
|
||||||
|
|
||||||
|
## 1. Price Prediction Model
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
PricePredictionModel(nn.Module)
|
||||||
|
├── Input Layer: [batch_size, seq_len, 2] (price, volume)
|
||||||
|
├── LSTM Layers: 2 stacked layers with hidden_size=128
|
||||||
|
├── Attention Mechanism: Self-attention with linear projections
|
||||||
|
├── Linear Layer 1: hidden_size → hidden_size
|
||||||
|
├── ReLU Activation
|
||||||
|
├── Linear Layer 2: hidden_size → output_size (5 future prices)
|
||||||
|
└── Output: [batch_size, output_size]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Data Flow
|
||||||
|
|
||||||
|
**Inputs:**
|
||||||
|
- `price_history`: Sequence of historical prices [batch_size, seq_len]
|
||||||
|
- `volume_history`: Sequence of historical volumes [batch_size, seq_len]
|
||||||
|
|
||||||
|
**Preprocessing:**
|
||||||
|
- Normalization using MinMaxScaler (0-1 range)
|
||||||
|
- Reshaping to [batch_size, seq_len, 2] (price and volume features)
|
||||||
|
|
||||||
|
**Forward Pass:**
|
||||||
|
1. Input data passes through LSTM layers
|
||||||
|
2. Self-attention mechanism applied to LSTM outputs
|
||||||
|
3. Linear layers process the attended features
|
||||||
|
4. Output represents predicted prices for next 5 candles
|
||||||
|
|
||||||
|
**Outputs:**
|
||||||
|
- `predicted_prices`: Array of 5 future price predictions
|
||||||
|
- `predicted_extrema`: Binary indicators for potential price extrema points
|
||||||
|
|
||||||
|
## 2. DQN (Deep Q-Network)
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
DQN(nn.Module)
|
||||||
|
├── Input Layer: [batch_size, state_size]
|
||||||
|
├── Linear Layer 1: state_size → hidden_size (384)
|
||||||
|
├── ReLU Activation
|
||||||
|
├── LSTM Layers: 2 stacked layers with hidden_size=384
|
||||||
|
├── Multi-Head Attention: 4 attention heads
|
||||||
|
├── Linear Layer 2: hidden_size → hidden_size
|
||||||
|
├── ReLU Activation
|
||||||
|
├── Linear Layer 3: hidden_size → action_size (4)
|
||||||
|
└── Output: [batch_size, action_size] (Q-values for each action)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Data Flow
|
||||||
|
|
||||||
|
**Inputs:**
|
||||||
|
- `state`: Current market state representation [batch_size, state_size]
|
||||||
|
- Price features (normalized prices, returns, volatility)
|
||||||
|
- Technical indicators (RSI, MACD, Stochastic,
|
@ -4,6 +4,12 @@ ensure we use GPU if available to train faster. during training we need to have
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
do we call the trading api proerly when in live mode? in live mode - also validate the current ballance. also ensure trades are executed after executing the orders by checking the open orders.
|
||||||
|
|
||||||
|
our trading data chart (in tensorboard) does not properly displayed - the candles seems displayed multiple times but shifted in time. we also do not correctly show the buy/sell evens on the time axis. we do not show the predicted price on the chart.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
|
2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
|
||||||
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
|
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
|
||||||
@ -22,3 +28,13 @@ Backend tkagg is interactive backend. Turning interactive mode on.
|
|||||||
|
|
||||||
|
|
||||||
2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
|
2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
|
||||||
|
|
||||||
|
|
||||||
|
----------------
|
||||||
|
|
||||||
|
do we train the price prediction model by using old known candles and masking the latest to make him guess the next and then backpropagating the next already known candle ? like a transformer (gpt2) would do? or we use RL for that as well?
|
||||||
|
it seems the model is not learning a lot. we keep hovering about the same starting balance even after some time in training in continious mode
|
||||||
|
|
||||||
|
|
||||||
|
it seems we may need another NN model down the loop jut to predict the extremums of the price.
|
||||||
|
we may have to include a mechanism to calculate the extremums of the price retrospectively and to use that to bootstrap pre-train the model.
|
||||||
|
362
crypto/gogo2/archive.py
Normal file
362
crypto/gogo2/archive.py
Normal file
@ -0,0 +1,362 @@
|
|||||||
|
class MexcTradingClient:
|
||||||
|
def __init__(self, api_key, secret_key, symbol="ETH/USDT", leverage=50):
|
||||||
|
self.client = ccxt.mexc({
|
||||||
|
'apiKey': api_key,
|
||||||
|
'secret': secret_key,
|
||||||
|
'enableRateLimit': True,
|
||||||
|
})
|
||||||
|
self.symbol = symbol
|
||||||
|
self.leverage = leverage
|
||||||
|
self.position = 'flat'
|
||||||
|
self.position_size = 0
|
||||||
|
self.entry_price = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
self.trades = []
|
||||||
|
|
||||||
|
def initialize_mexc_client(self, api_key, api_secret):
|
||||||
|
"""Initialize the MEXC API client"""
|
||||||
|
try:
|
||||||
|
from mexc_sdk import Spot
|
||||||
|
self.mexc_client = Spot(api_key=api_key, api_secret=api_secret)
|
||||||
|
# Test connection
|
||||||
|
self.mexc_client.ping()
|
||||||
|
logger.info("MEXC API client initialized successfully")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize MEXC API client: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def fetch_account_balance(self):
|
||||||
|
"""Fetch actual account balance from MEXC API"""
|
||||||
|
if self.demo or not self.mexc_client:
|
||||||
|
# In demo mode, use simulated balance
|
||||||
|
return self.balance
|
||||||
|
|
||||||
|
try:
|
||||||
|
account_info = self.mexc_client.accountInfo()
|
||||||
|
if 'balances' in account_info:
|
||||||
|
# Find USDT balance
|
||||||
|
for asset in account_info['balances']:
|
||||||
|
if asset['asset'] == 'USDT':
|
||||||
|
return float(asset['free'])
|
||||||
|
|
||||||
|
logger.warning("Could not find USDT balance, using current simulated balance")
|
||||||
|
return self.balance
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching account balance: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
# Fallback to simulated balance in case of API error
|
||||||
|
return self.balance
|
||||||
|
|
||||||
|
async def fetch_open_positions(self):
|
||||||
|
"""Fetch actual open positions from MEXC API"""
|
||||||
|
if self.demo or not self.mexc_client:
|
||||||
|
# In demo mode, return current simulated position
|
||||||
|
return [{
|
||||||
|
'symbol': 'ETH/USDT',
|
||||||
|
'positionSide': 'LONG' if self.position == 'long' else 'SHORT' if self.position == 'short' else 'NONE',
|
||||||
|
'positionAmt': self.position_size / self.current_price if self.position != 'flat' else 0,
|
||||||
|
'entryPrice': self.entry_price,
|
||||||
|
'unrealizedProfit': self.calculate_unrealized_pnl()
|
||||||
|
}] if self.position != 'flat' else []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch open positions
|
||||||
|
positions = self.mexc_client.openOrders('ETH/USDT')
|
||||||
|
return positions
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching open positions: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
# Fallback to simulated positions in case of API error
|
||||||
|
return []
|
||||||
|
|
||||||
|
def calculate_unrealized_pnl(self):
|
||||||
|
"""Calculate unrealized PnL for the current position"""
|
||||||
|
if self.position == 'flat':
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
position_value = self.position_size / self.entry_price
|
||||||
|
|
||||||
|
if self.position == 'long':
|
||||||
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
|
||||||
|
else: # short
|
||||||
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
||||||
|
|
||||||
|
# Apply leverage
|
||||||
|
pnl_percent *= self.leverage
|
||||||
|
|
||||||
|
return position_value * pnl_percent / 100
|
||||||
|
|
||||||
|
async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
|
||||||
|
"""Open a new position using MEXC API in live mode, or simulate in demo mode"""
|
||||||
|
if self.demo or not self.mexc_client:
|
||||||
|
# In demo mode, simulate opening a position
|
||||||
|
self.position = position_type
|
||||||
|
self.position_size = size
|
||||||
|
self.entry_price = entry_price
|
||||||
|
self.entry_index = self.current_step
|
||||||
|
self.stop_loss = stop_loss
|
||||||
|
self.take_profit = take_profit
|
||||||
|
|
||||||
|
logger.info(f"DEMO: Opened {position_type.upper()} position at {entry_price} | " +
|
||||||
|
f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# In live mode, place actual orders via API
|
||||||
|
symbol = "ETHUSDT" # Format required by MEXC
|
||||||
|
side = "BUY" if position_type == 'long' else "SELL"
|
||||||
|
|
||||||
|
# Calculate quantity based on size and price
|
||||||
|
quantity = size / entry_price
|
||||||
|
|
||||||
|
# Place main order
|
||||||
|
order_result = self.mexc_client.newOrder(
|
||||||
|
symbol=symbol,
|
||||||
|
side=side,
|
||||||
|
orderType="MARKET",
|
||||||
|
quantity=quantity,
|
||||||
|
options={
|
||||||
|
"leverage": self.leverage,
|
||||||
|
"newOrderRespType": "FULL"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if order executed
|
||||||
|
if order_result.get('status') == 'FILLED':
|
||||||
|
actual_entry_price = float(order_result.get('price', entry_price))
|
||||||
|
|
||||||
|
# Place stop loss order
|
||||||
|
sl_order = self.mexc_client.newOrder(
|
||||||
|
symbol=symbol,
|
||||||
|
side="SELL" if position_type == 'long' else "BUY",
|
||||||
|
orderType="STOP_LOSS",
|
||||||
|
quantity=quantity,
|
||||||
|
options={
|
||||||
|
"stopPrice": stop_loss,
|
||||||
|
"newOrderRespType": "ACK"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Place take profit order
|
||||||
|
tp_order = self.mexc_client.newOrder(
|
||||||
|
symbol=symbol,
|
||||||
|
side="SELL" if position_type == 'long' else "BUY",
|
||||||
|
orderType="TAKE_PROFIT",
|
||||||
|
quantity=quantity,
|
||||||
|
options={
|
||||||
|
"stopPrice": take_profit,
|
||||||
|
"newOrderRespType": "ACK"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update local state
|
||||||
|
self.position = position_type
|
||||||
|
self.position_size = size
|
||||||
|
self.entry_price = actual_entry_price
|
||||||
|
self.entry_index = self.current_step
|
||||||
|
self.stop_loss = stop_loss
|
||||||
|
self.take_profit = take_profit
|
||||||
|
|
||||||
|
# Track orders
|
||||||
|
self.open_orders.extend([sl_order, tp_order])
|
||||||
|
self.order_history.append(order_result)
|
||||||
|
|
||||||
|
logger.info(f"LIVE: Opened {position_type.upper()} position at {actual_entry_price} | " +
|
||||||
|
f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to execute order: {order_result}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error opening position: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close_position(self, reason="manual_close"):
|
||||||
|
"""Close the current position using MEXC API in live mode, or simulate in demo mode"""
|
||||||
|
if self.position == 'flat':
|
||||||
|
logger.info("No position to close")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.demo or not self.mexc_client:
|
||||||
|
# In demo mode, simulate closing a position
|
||||||
|
position_type = self.position
|
||||||
|
entry_price = self.entry_price
|
||||||
|
exit_price = self.current_price
|
||||||
|
position_size = self.position_size
|
||||||
|
|
||||||
|
# Calculate PnL
|
||||||
|
if position_type == 'long':
|
||||||
|
pnl_percent = (exit_price - entry_price) / entry_price * 100
|
||||||
|
else: # short
|
||||||
|
pnl_percent = (entry_price - exit_price) / entry_price * 100
|
||||||
|
|
||||||
|
# Apply leverage
|
||||||
|
pnl_percent *= self.leverage
|
||||||
|
|
||||||
|
# Calculate actual PnL
|
||||||
|
pnl_dollar = position_size * pnl_percent / 100
|
||||||
|
|
||||||
|
# Apply fees
|
||||||
|
pnl_dollar -= self.calculate_fees(position_size)
|
||||||
|
|
||||||
|
# Update balance
|
||||||
|
self.balance += pnl_dollar
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
self.episode_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Update max drawdown
|
||||||
|
if self.balance > self.peak_balance:
|
||||||
|
self.peak_balance = self.balance
|
||||||
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
||||||
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
self.trades.append({
|
||||||
|
'type': position_type,
|
||||||
|
'entry': entry_price,
|
||||||
|
'exit': exit_price,
|
||||||
|
'entry_time': self.data[self.entry_index]['timestamp'],
|
||||||
|
'exit_time': self.data[self.current_step]['timestamp'],
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'duration': self.current_step - self.entry_index,
|
||||||
|
'market_direction': self.get_market_direction(),
|
||||||
|
'reason': reason,
|
||||||
|
'leverage': self.leverage
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
logger.info(f"DEMO: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
|
|
||||||
|
# Reset position
|
||||||
|
self.position = 'flat'
|
||||||
|
self.entry_price = 0
|
||||||
|
self.entry_index = 0
|
||||||
|
self.position_size = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# In live mode, close position via API
|
||||||
|
symbol = "ETHUSDT"
|
||||||
|
position_info = await self.fetch_open_positions()
|
||||||
|
|
||||||
|
if not position_info:
|
||||||
|
logger.warning("No open positions found to close")
|
||||||
|
self.position = 'flat'
|
||||||
|
return False
|
||||||
|
|
||||||
|
# First, cancel any existing stop loss/take profit orders
|
||||||
|
try:
|
||||||
|
self.mexc_client.cancelOpenOrders(symbol)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error canceling open orders: {e}")
|
||||||
|
|
||||||
|
# Close the position with a market order
|
||||||
|
position_type = self.position
|
||||||
|
side = "SELL" if position_type == 'long' else "BUY"
|
||||||
|
quantity = self.position_size / self.current_price
|
||||||
|
|
||||||
|
# Execute order
|
||||||
|
order_result = self.mexc_client.newOrder(
|
||||||
|
symbol=symbol,
|
||||||
|
side=side,
|
||||||
|
orderType="MARKET",
|
||||||
|
quantity=quantity,
|
||||||
|
options={
|
||||||
|
"newOrderRespType": "FULL"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if order executed
|
||||||
|
if order_result.get('status') == 'FILLED':
|
||||||
|
exit_price = float(order_result.get('price', self.current_price))
|
||||||
|
entry_price = self.entry_price
|
||||||
|
position_size = self.position_size
|
||||||
|
|
||||||
|
# Calculate PnL
|
||||||
|
if position_type == 'long':
|
||||||
|
pnl_percent = (exit_price - entry_price) / entry_price * 100
|
||||||
|
else: # short
|
||||||
|
pnl_percent = (entry_price - exit_price) / entry_price * 100
|
||||||
|
|
||||||
|
# Apply leverage
|
||||||
|
pnl_percent *= self.leverage
|
||||||
|
|
||||||
|
# Calculate actual PnL
|
||||||
|
pnl_dollar = position_size * pnl_percent / 100
|
||||||
|
|
||||||
|
# Apply fees
|
||||||
|
pnl_dollar -= self.calculate_fees(position_size)
|
||||||
|
|
||||||
|
# Update balance from API
|
||||||
|
self.balance = await self.fetch_account_balance()
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
self.episode_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Update max drawdown
|
||||||
|
if self.balance > self.peak_balance:
|
||||||
|
self.peak_balance = self.balance
|
||||||
|
drawdown = (self.peak_balance - self.balance) / self.peak_balance
|
||||||
|
self.max_drawdown = max(self.max_drawdown, drawdown)
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
self.trades.append({
|
||||||
|
'type': position_type,
|
||||||
|
'entry': entry_price,
|
||||||
|
'exit': exit_price,
|
||||||
|
'entry_time': self.data[self.entry_index]['timestamp'],
|
||||||
|
'exit_time': self.data[self.current_step]['timestamp'],
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'duration': self.current_step - self.entry_index,
|
||||||
|
'market_direction': self.get_market_direction(),
|
||||||
|
'reason': reason,
|
||||||
|
'leverage': self.leverage,
|
||||||
|
'order_id': order_result.get('orderId')
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
# Track order history
|
||||||
|
self.order_history.append(order_result)
|
||||||
|
|
||||||
|
logger.info(f"LIVE: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
|
|
||||||
|
# Reset position
|
||||||
|
self.position = 'flat'
|
||||||
|
self.entry_price = 0
|
||||||
|
self.entry_index = 0
|
||||||
|
self.position_size = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to close position: {order_result}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing position: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1 +1 @@
|
|||||||
{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}
|
{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 0, "timestamp": "2025-03-12T00:23:19.125190"}
|
207
crypto/gogo2/count_params.py
Normal file
207
crypto/gogo2/count_params.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
||||||
|
|
||||||
|
class PricePredictionModel(nn.Module):
|
||||||
|
def __init__(self, input_dim=2, hidden_dim=128, num_layers=2, output_dim=5):
|
||||||
|
super(PricePredictionModel, self).__init__()
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# LSTM layers
|
||||||
|
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
|
||||||
|
|
||||||
|
# Self-attention mechanism
|
||||||
|
self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
|
||||||
|
|
||||||
|
# Fully connected layer for price prediction
|
||||||
|
self.price_fc = nn.Linear(hidden_dim, output_dim)
|
||||||
|
|
||||||
|
# Fully connected layer for extrema prediction (high and low points)
|
||||||
|
self.extrema_fc = nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x shape: (batch_size, seq_len, input_dim)
|
||||||
|
|
||||||
|
# LSTM forward pass
|
||||||
|
lstm_out, _ = self.lstm(x) # lstm_out: (batch_size, seq_len, hidden_dim)
|
||||||
|
|
||||||
|
# Self-attention
|
||||||
|
attn_output, _ = self.attention(lstm_out, lstm_out, lstm_out)
|
||||||
|
|
||||||
|
# Price prediction
|
||||||
|
price_pred = self.price_fc(attn_output[:, -1, :]) # Use the last time step
|
||||||
|
|
||||||
|
# Extrema prediction
|
||||||
|
extrema_logits = self.extrema_fc(attn_output[:, -1, :])
|
||||||
|
|
||||||
|
return price_pred, extrema_logits
|
||||||
|
|
||||||
|
class DQN(nn.Module):
|
||||||
|
def __init__(self, state_dim, action_dim, hidden_dim=256):
|
||||||
|
super(DQN, self).__init__()
|
||||||
|
|
||||||
|
# Feature extraction layers
|
||||||
|
self.feature_extraction = nn.Sequential(
|
||||||
|
nn.Linear(state_dim, hidden_dim),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advantage stream
|
||||||
|
self.advantage_stream = nn.Sequential(
|
||||||
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.Linear(hidden_dim, action_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Value stream
|
||||||
|
self.value_stream = nn.Sequential(
|
||||||
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.Linear(hidden_dim, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transformer for temporal dependencies
|
||||||
|
encoder_layers = TransformerEncoderLayer(d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim*4, batch_first=True)
|
||||||
|
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
|
||||||
|
|
||||||
|
# LSTM for sequential decision making
|
||||||
|
self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
|
||||||
|
|
||||||
|
# Final layers
|
||||||
|
self.final_layers = nn.Sequential(
|
||||||
|
nn.Linear(hidden_dim*2, hidden_dim),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.Linear(hidden_dim, action_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, state, hidden=None):
|
||||||
|
# Extract features
|
||||||
|
features = self.feature_extraction(state)
|
||||||
|
features = features.unsqueeze(1) # Add sequence dimension for transformer/LSTM
|
||||||
|
|
||||||
|
# Transformer processing
|
||||||
|
transformer_out = self.transformer(features)
|
||||||
|
|
||||||
|
# LSTM processing
|
||||||
|
lstm_out, lstm_hidden = self.lstm(transformer_out)
|
||||||
|
|
||||||
|
# Dueling architecture
|
||||||
|
advantage = self.advantage_stream(features.squeeze(1))
|
||||||
|
value = self.value_stream(features.squeeze(1))
|
||||||
|
|
||||||
|
# Combine transformer, LSTM and dueling outputs
|
||||||
|
combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
|
||||||
|
q_values = self.final_layers(combined)
|
||||||
|
|
||||||
|
# Dueling Q-value computation
|
||||||
|
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
return q_values, lstm_hidden
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
total_params = 0
|
||||||
|
layer_params = {}
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
param_count = param.numel()
|
||||||
|
total_params += param_count
|
||||||
|
layer_params[name] = (param_count, param.shape)
|
||||||
|
|
||||||
|
return total_params, layer_params
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Initialize the Price Prediction Model
|
||||||
|
price_model = PricePredictionModel()
|
||||||
|
price_total_params, price_layer_params = count_parameters(price_model)
|
||||||
|
|
||||||
|
print(f"Price Prediction Model parameters: {price_total_params:,}")
|
||||||
|
print("\nPrice Prediction Model Layers:")
|
||||||
|
for name, (count, shape) in price_layer_params.items():
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
|
||||||
|
# Initialize the DQN Model with typical dimensions
|
||||||
|
state_dim = 50 # Typical state dimension for the trading bot
|
||||||
|
action_dim = 3 # Typical action dimension (buy, sell, hold)
|
||||||
|
dqn_model = DQN(state_dim=state_dim, action_dim=action_dim)
|
||||||
|
dqn_total_params, dqn_layer_params = count_parameters(dqn_model)
|
||||||
|
|
||||||
|
# Count parameters by category
|
||||||
|
feature_extraction_params = sum(count for name, (count, _) in dqn_layer_params.items() if "feature_extraction" in name)
|
||||||
|
advantage_value_params = sum(count for name, (count, _) in dqn_layer_params.items() if "advantage_stream" in name or "value_stream" in name)
|
||||||
|
transformer_params = sum(count for name, (count, _) in dqn_layer_params.items() if "transformer" in name)
|
||||||
|
lstm_params = sum(count for name, (count, _) in dqn_layer_params.items() if "lstm" in name and "transformer" not in name)
|
||||||
|
final_layers_params = sum(count for name, (count, _) in dqn_layer_params.items() if "final_layers" in name)
|
||||||
|
|
||||||
|
print(f"\nDQN Model parameters: {dqn_total_params:,}")
|
||||||
|
|
||||||
|
# Create sets to track which parameters we've printed
|
||||||
|
printed_params = set()
|
||||||
|
|
||||||
|
# Print DQN layers in groups to avoid output truncation
|
||||||
|
print(f"\nDQN Model Layers (Feature Extraction): {feature_extraction_params:,} parameters")
|
||||||
|
for name, (count, shape) in dqn_layer_params.items():
|
||||||
|
if "feature_extraction" in name:
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
printed_params.add(name)
|
||||||
|
|
||||||
|
print(f"\nDQN Model Layers (Advantage & Value Streams): {advantage_value_params:,} parameters")
|
||||||
|
for name, (count, shape) in dqn_layer_params.items():
|
||||||
|
if "advantage_stream" in name or "value_stream" in name:
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
printed_params.add(name)
|
||||||
|
|
||||||
|
print(f"\nDQN Model Layers (Transformer): {transformer_params:,} parameters")
|
||||||
|
for name, (count, shape) in dqn_layer_params.items():
|
||||||
|
if "transformer" in name:
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
printed_params.add(name)
|
||||||
|
|
||||||
|
print(f"\nDQN Model Layers (LSTM): {lstm_params:,} parameters")
|
||||||
|
for name, (count, shape) in dqn_layer_params.items():
|
||||||
|
if "lstm" in name and "transformer" not in name:
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
printed_params.add(name)
|
||||||
|
|
||||||
|
print(f"\nDQN Model Layers (Final Layers): {final_layers_params:,} parameters")
|
||||||
|
for name, (count, shape) in dqn_layer_params.items():
|
||||||
|
if "final_layers" in name:
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
printed_params.add(name)
|
||||||
|
|
||||||
|
# Print any remaining parameters that weren't caught by the categories above
|
||||||
|
remaining_params = set(dqn_layer_params.keys()) - printed_params
|
||||||
|
if remaining_params:
|
||||||
|
remaining_params_count = sum(dqn_layer_params[name][0] for name in remaining_params)
|
||||||
|
print(f"\nDQN Model Layers (Other): {remaining_params_count:,} parameters")
|
||||||
|
for name in remaining_params:
|
||||||
|
count, shape = dqn_layer_params[name]
|
||||||
|
print(f"{name}: {count:,} (shape: {shape})")
|
||||||
|
|
||||||
|
# Total parameters across both models
|
||||||
|
print(f"\nTotal parameters (both models): {price_total_params + dqn_total_params:,}")
|
||||||
|
|
||||||
|
# Print summary of parameter distribution
|
||||||
|
print("\nParameter Distribution Summary:")
|
||||||
|
print(f"Price Prediction Model: {price_total_params:,} parameters ({price_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
|
||||||
|
print(f"DQN Model: {dqn_total_params:,} parameters ({dqn_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
|
||||||
|
print("\nDQN Model Breakdown:")
|
||||||
|
print(f"- Feature Extraction: {feature_extraction_params:,} parameters ({feature_extraction_params/dqn_total_params*100:.1f}%)")
|
||||||
|
print(f"- Advantage & Value Streams: {advantage_value_params:,} parameters ({advantage_value_params/dqn_total_params*100:.1f}%)")
|
||||||
|
print(f"- Transformer: {transformer_params:,} parameters ({transformer_params/dqn_total_params*100:.1f}%)")
|
||||||
|
print(f"- LSTM: {lstm_params:,} parameters ({lstm_params/dqn_total_params*100:.1f}%)")
|
||||||
|
print(f"- Final Layers: {final_layers_params:,} parameters ({final_layers_params/dqn_total_params*100:.1f}%)")
|
||||||
|
|
||||||
|
# Verify that all parameters are accounted for
|
||||||
|
total_by_category = feature_extraction_params + advantage_value_params + transformer_params + lstm_params + final_layers_params
|
||||||
|
if remaining_params:
|
||||||
|
total_by_category += remaining_params_count
|
||||||
|
print(f"\nSum of all categories: {total_by_category:,} parameters")
|
||||||
|
print(f"Difference from total: {dqn_total_params - total_by_category:,} parameters")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
42
crypto/gogo2/fix_indentation.py
Normal file
42
crypto/gogo2/fix_indentation.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
def fix_indentation():
|
||||||
|
with open('main.py', 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
# Fix indentation for the problematic sections
|
||||||
|
fixed_lines = []
|
||||||
|
|
||||||
|
# Find the try block that starts at line 1693
|
||||||
|
try_start_line = 1693
|
||||||
|
try_block_found = False
|
||||||
|
in_try_block = False
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
# Check if we're at the try statement
|
||||||
|
if i+1 == try_start_line and 'try:' in line:
|
||||||
|
try_block_found = True
|
||||||
|
in_try_block = True
|
||||||
|
fixed_lines.append(line)
|
||||||
|
# Fix the indentation of the experiences line
|
||||||
|
elif i+1 == 1695 and line.strip().startswith('experiences = self.memory.sample(BATCH_SIZE)'):
|
||||||
|
# Add proper indentation (4 spaces)
|
||||||
|
fixed_lines.append(' ' + line.lstrip())
|
||||||
|
# Check if we're at the end of the try block without an except
|
||||||
|
elif try_block_found and in_try_block and i+1 > try_start_line and line.strip() and not line.startswith(' '):
|
||||||
|
# We've reached the end of the try block without an except, add one
|
||||||
|
fixed_lines.append(' except Exception as e:\n')
|
||||||
|
fixed_lines.append(' logger.error(f"Error during learning: {e}")\n')
|
||||||
|
fixed_lines.append(' logger.error(f"Traceback: {traceback.format_exc()}")\n')
|
||||||
|
fixed_lines.append(' return None\n\n')
|
||||||
|
in_try_block = False
|
||||||
|
fixed_lines.append(line)
|
||||||
|
else:
|
||||||
|
fixed_lines.append(line)
|
||||||
|
|
||||||
|
# Write the fixed content back to the file
|
||||||
|
with open('main.py', 'w') as f:
|
||||||
|
f.writelines(fixed_lines)
|
||||||
|
|
||||||
|
print("Indentation fixed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fix_indentation()
|
22
crypto/gogo2/fix_try_blocks.py
Normal file
22
crypto/gogo2/fix_try_blocks.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
def fix_try_blocks():
|
||||||
|
with open('main.py', 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Find all try blocks without except or finally
|
||||||
|
pattern = r'(\s+)try:\s*\n((?:\1\s+.*\n)+?)(?!\1\s*except|\1\s*finally)'
|
||||||
|
|
||||||
|
# Replace with try-except blocks
|
||||||
|
fixed_content = re.sub(pattern,
|
||||||
|
r'\1try:\n\2\1except Exception as e:\n\1 logger.error(f"Error: {e}")\n\1 logger.error(f"Traceback: {traceback.format_exc()}")\n\1 return None\n\n',
|
||||||
|
content)
|
||||||
|
|
||||||
|
# Write the fixed content back to the file
|
||||||
|
with open('main.py', 'w') as f:
|
||||||
|
f.write(fixed_content)
|
||||||
|
|
||||||
|
print("Try blocks fixed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fix_try_blocks()
|
@ -1151,10 +1151,10 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Reward based on PnL
|
# Reward based on PnL
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
else:
|
else:
|
||||||
reward = -1.0 # Negative reward for loss
|
reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
@ -1238,10 +1238,15 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Reward based on PnL
|
# Reward based on PnL
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
|
|
||||||
|
# Extra reward for closing at a predicted high
|
||||||
|
if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
|
||||||
|
reward += 1.0
|
||||||
|
logger.info("Closing long at predicted high - additional reward")
|
||||||
else:
|
else:
|
||||||
reward = -1.0 # Negative reward for loss
|
reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
@ -1296,15 +1301,15 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Reward based on PnL
|
# Reward based on PnL
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
|
|
||||||
# Extra reward for closing at a predicted high
|
# Extra reward for closing at a predicted high
|
||||||
if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
|
if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
|
||||||
reward += 0.5
|
reward += 1.0
|
||||||
logger.info("Closing long at predicted high - additional reward")
|
logger.info("Closing long at predicted high - additional reward")
|
||||||
else:
|
else:
|
||||||
reward = -1.0 # Negative reward for loss
|
reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
@ -1354,15 +1359,15 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Reward based on PnL
|
# Reward based on PnL
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
|
|
||||||
# Extra reward for closing at a predicted low
|
# Extra reward for closing at a predicted low
|
||||||
if hasattr(self, 'has_predicted_low') and self.has_predicted_low:
|
if hasattr(self, 'has_predicted_low') and self.has_predicted_low:
|
||||||
reward += 0.5
|
reward += 1.0
|
||||||
logger.info("Closing short at predicted low - additional reward")
|
logger.info("Closing short at predicted low - additional reward")
|
||||||
else:
|
else:
|
||||||
reward = -1.0 # Negative reward for loss
|
reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
@ -1378,9 +1383,9 @@ class TradingEnvironment:
|
|||||||
# Add reward based on direct PnL change
|
# Add reward based on direct PnL change
|
||||||
balance_change = self.balance - prev_balance
|
balance_change = self.balance - prev_balance
|
||||||
if balance_change > 0:
|
if balance_change > 0:
|
||||||
reward += balance_change * 0.5 # Positive reward for making money
|
reward += balance_change * 1.0 # Increased positive reward for making money
|
||||||
else:
|
else:
|
||||||
reward += balance_change * 1.0 # Stronger negative reward for losing money
|
reward += balance_change * 2.0 # Stronger negative reward for losing money
|
||||||
|
|
||||||
# Add reward for predicted price movement alignment
|
# Add reward for predicted price movement alignment
|
||||||
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
|
if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
|
||||||
@ -1611,9 +1616,15 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
def initialize_price_predictor(self, device="cpu"):
|
def initialize_price_predictor(self, device="cpu"):
|
||||||
"""Initialize the price prediction model"""
|
"""Initialize the price prediction model"""
|
||||||
|
# Only create a new model if one doesn't already exist
|
||||||
|
if not hasattr(self, 'price_predictor') or self.price_predictor is None:
|
||||||
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
|
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
|
||||||
self.price_predictor.to(device)
|
self.price_predictor.to(device)
|
||||||
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
|
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
|
||||||
|
else:
|
||||||
|
# If model exists, just ensure it's on the right device
|
||||||
|
self.price_predictor.to(device)
|
||||||
|
|
||||||
self.predicted_prices = np.array([])
|
self.predicted_prices = np.array([])
|
||||||
self.predicted_extrema = np.array([])
|
self.predicted_extrema = np.array([])
|
||||||
self.extrema_threshold = 0.7 # Threshold for extrema prediction confidence
|
self.extrema_threshold = 0.7 # Threshold for extrema prediction confidence
|
||||||
@ -1766,16 +1777,16 @@ class TradingEnvironment:
|
|||||||
return fee
|
return fee
|
||||||
|
|
||||||
# Ensure GPU usage if available
|
# Ensure GPU usage if available
|
||||||
def get_device(device_preference='gpu'):
|
def get_device(device_preference='auto'):
|
||||||
"""Get the device to use (GPU or CPU) based on preference and availability"""
|
"""Get the device to use (GPU or CPU) based on preference and availability"""
|
||||||
if device_preference.lower() == 'gpu' and torch.cuda.is_available():
|
if device_preference.lower() in ['gpu', 'auto'] and torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
# Set default tensor type to float32 for CUDA
|
# Set default tensor type to float32 for CUDA
|
||||||
torch.set_default_tensor_type(torch.FloatTensor)
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if device_preference.lower() == 'gpu':
|
if device_preference.lower() in ['gpu', 'auto']:
|
||||||
logger.info("GPU requested but not available, using CPU instead")
|
logger.info("GPU requested but not available, using CPU instead")
|
||||||
else:
|
else:
|
||||||
logger.info("Using CPU as requested")
|
logger.info("Using CPU as requested")
|
||||||
@ -1952,7 +1963,7 @@ class Agent:
|
|||||||
|
|
||||||
# Use mixed precision for forward/backward passes
|
# Use mixed precision for forward/backward passes
|
||||||
if self.device.type == "cuda":
|
if self.device.type == "cuda":
|
||||||
with amp.autocast():
|
with amp.autocast(device_type='cuda'):
|
||||||
# Compute Q values
|
# Compute Q values
|
||||||
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
||||||
|
|
||||||
@ -2943,10 +2954,12 @@ async def main():
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run the trading bot')
|
parser = argparse.ArgumentParser(description='Run the trading bot')
|
||||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'], help='Mode to run the bot in')
|
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live', 'continuous'], help='Mode to run the bot in')
|
||||||
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train for')
|
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train for')
|
||||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||||
parser.add_argument('--device', type=str, default='auto', choices=['cpu', 'gpu', 'auto'], help='Device to use for training')
|
parser.add_argument('--device', type=str, default='auto', choices=['cpu', 'gpu', 'auto'], help='Device to use for training')
|
||||||
|
parser.add_argument('--refresh-data', '--refresh_data', dest='refresh_data', action='store_true', help='Refresh data at the start of each episode')
|
||||||
|
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for data (e.g., 1s, 1m, 5m, 15m, 1h)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set device
|
# Set device
|
||||||
@ -2995,6 +3008,247 @@ async def main():
|
|||||||
results = evaluate_agent(agent, env, num_episodes=10)
|
results = evaluate_agent(agent, env, num_episodes=10)
|
||||||
logger.info(f"Evaluation results: {results}")
|
logger.info(f"Evaluation results: {results}")
|
||||||
|
|
||||||
|
elif args.mode == 'continuous':
|
||||||
|
# Continuous training mode - train indefinitely with data refreshing
|
||||||
|
logger.info("Starting continuous training mode...")
|
||||||
|
|
||||||
|
# Set refresh_data to True for continuous mode
|
||||||
|
args.refresh_data = True
|
||||||
|
|
||||||
|
# Create directories for continuous models
|
||||||
|
os.makedirs("models", exist_ok=True)
|
||||||
|
|
||||||
|
# Track best PnL for model selection
|
||||||
|
best_pnl = float('-inf')
|
||||||
|
best_pnl_model_path = "models/trading_agent_best_pnl.pt"
|
||||||
|
|
||||||
|
# Load the best PnL model if it exists
|
||||||
|
if os.path.exists(best_pnl_model_path):
|
||||||
|
logger.info(f"Loading best PnL model: {best_pnl_model_path}")
|
||||||
|
agent.load(best_pnl_model_path)
|
||||||
|
|
||||||
|
# Try to load best PnL value from checkpoint file
|
||||||
|
checkpoint_info_path = "checkpoints/best_metrics.json"
|
||||||
|
if os.path.exists(checkpoint_info_path):
|
||||||
|
with open(checkpoint_info_path, 'r') as f:
|
||||||
|
best_metrics = json.load(f)
|
||||||
|
best_pnl = best_metrics.get('best_pnl', best_pnl)
|
||||||
|
logger.info(f"Loaded best PnL from checkpoint: ${best_pnl:.2f}")
|
||||||
|
|
||||||
|
# Initialize episode counter
|
||||||
|
episode = 0
|
||||||
|
|
||||||
|
# Get timeframe from args
|
||||||
|
timeframe = args.timeframe
|
||||||
|
logger.info(f"Using timeframe: {timeframe}")
|
||||||
|
|
||||||
|
# Initialize TensorBoard writer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
tensorboard_dir = f"runs/continuous_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
writer = SummaryWriter(tensorboard_dir)
|
||||||
|
logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
|
||||||
|
|
||||||
|
# Attach writer to agent
|
||||||
|
agent.writer = writer
|
||||||
|
|
||||||
|
# Initialize stats dictionary for plotting
|
||||||
|
stats = {
|
||||||
|
'episode_rewards': [],
|
||||||
|
'episode_profits': [],
|
||||||
|
'win_rates': [],
|
||||||
|
'trade_counts': [],
|
||||||
|
'prediction_accuracies': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Train continuously
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
logger.info(f"Continuous training - Episode {episode}")
|
||||||
|
|
||||||
|
# Refresh data from exchange with the specified timeframe
|
||||||
|
logger.info(f"Refreshing market data with timeframe {timeframe}...")
|
||||||
|
await env.fetch_new_data(exchange, "ETH/USDT", timeframe, 100)
|
||||||
|
|
||||||
|
# Reset environment
|
||||||
|
state = env.reset()
|
||||||
|
|
||||||
|
# Initialize price predictor if not already initialized
|
||||||
|
if not hasattr(env, 'price_predictor') or env.price_predictor is None:
|
||||||
|
logger.info("Initializing price predictor...")
|
||||||
|
env.initialize_price_predictor(device=agent.device)
|
||||||
|
|
||||||
|
# Initialize episode variables
|
||||||
|
episode_reward = 0
|
||||||
|
done = False
|
||||||
|
|
||||||
|
# Train price predictor
|
||||||
|
prediction_loss, extrema_loss = env.train_price_predictor()
|
||||||
|
|
||||||
|
# Update price predictions
|
||||||
|
env.update_price_predictions()
|
||||||
|
|
||||||
|
# Training loop for this episode
|
||||||
|
while not done:
|
||||||
|
# Select action
|
||||||
|
action = agent.select_action(state)
|
||||||
|
|
||||||
|
# Take action
|
||||||
|
next_state, reward, done = env.step(action)
|
||||||
|
|
||||||
|
# Store experience
|
||||||
|
agent.memory.push(state, action, reward, next_state, done)
|
||||||
|
|
||||||
|
# Learn from experience
|
||||||
|
loss = agent.learn()
|
||||||
|
|
||||||
|
# Update state and reward
|
||||||
|
state = next_state
|
||||||
|
episode_reward += reward
|
||||||
|
|
||||||
|
# Calculate win rate
|
||||||
|
total_trades = env.win_count + env.loss_count
|
||||||
|
win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
|
||||||
|
|
||||||
|
# Calculate prediction accuracy
|
||||||
|
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
||||||
|
# Compare predictions with actual prices
|
||||||
|
actual_prices = env.features['price'][-len(env.predicted_prices):]
|
||||||
|
prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices
|
||||||
|
prediction_accuracy = 100 * (1 - np.mean(prediction_errors))
|
||||||
|
else:
|
||||||
|
prediction_accuracy = 0
|
||||||
|
|
||||||
|
# Update stats
|
||||||
|
stats['episode_rewards'].append(episode_reward)
|
||||||
|
stats['episode_profits'].append(env.episode_pnl)
|
||||||
|
stats['win_rates'].append(win_rate)
|
||||||
|
stats['trade_counts'].append(total_trades)
|
||||||
|
stats['prediction_accuracies'].append(prediction_accuracy)
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
writer.add_scalar('Reward/continuous', episode_reward, episode)
|
||||||
|
writer.add_scalar('Balance/continuous', env.balance, episode)
|
||||||
|
writer.add_scalar('WinRate/continuous', win_rate, episode)
|
||||||
|
writer.add_scalar('PnL/episode', env.episode_pnl, episode)
|
||||||
|
writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
|
||||||
|
writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
|
||||||
|
writer.add_scalar('PredictionLoss', prediction_loss, episode)
|
||||||
|
writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
|
||||||
|
|
||||||
|
# Log OHLCV data to TensorBoard every 5 episodes
|
||||||
|
if episode % 5 == 0:
|
||||||
|
# Create a DataFrame from the environment's data
|
||||||
|
df_ohlcv = pd.DataFrame([{
|
||||||
|
'timestamp': candle['timestamp'],
|
||||||
|
'open': candle['open'],
|
||||||
|
'high': candle['high'],
|
||||||
|
'low': candle['low'],
|
||||||
|
'close': candle['close'],
|
||||||
|
'volume': candle['volume']
|
||||||
|
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||||
|
|
||||||
|
# Convert timestamp to datetime
|
||||||
|
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||||
|
df_ohlcv.set_index('timestamp', inplace=True)
|
||||||
|
|
||||||
|
# Extract buy/sell signals from trades
|
||||||
|
buy_signals = []
|
||||||
|
sell_signals = []
|
||||||
|
|
||||||
|
if hasattr(env, 'trades') and env.trades:
|
||||||
|
for trade in env.trades:
|
||||||
|
if 'entry_time' in trade and 'entry' in trade:
|
||||||
|
if trade['type'] == 'long':
|
||||||
|
# Buy signal
|
||||||
|
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||||
|
buy_signals.append((entry_time, trade['entry']))
|
||||||
|
|
||||||
|
# Sell signal if closed
|
||||||
|
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||||
|
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||||
|
sell_signals.append((exit_time, trade['exit']))
|
||||||
|
|
||||||
|
elif trade['type'] == 'short':
|
||||||
|
# Sell short signal
|
||||||
|
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||||
|
sell_signals.append((entry_time, trade['entry']))
|
||||||
|
|
||||||
|
# Buy to cover signal if closed
|
||||||
|
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||||
|
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||||
|
buy_signals.append((exit_time, trade['exit']))
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
log_ohlcv_to_tensorboard(
|
||||||
|
writer,
|
||||||
|
df_ohlcv,
|
||||||
|
buy_signals,
|
||||||
|
sell_signals,
|
||||||
|
episode,
|
||||||
|
tag_prefix=f"continuous_episode_{episode}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
|
||||||
|
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
|
||||||
|
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}")
|
||||||
|
|
||||||
|
# Create visualization every 10 episodes
|
||||||
|
if episode % 10 == 0:
|
||||||
|
# Create visualization
|
||||||
|
os.makedirs("visualizations", exist_ok=True)
|
||||||
|
visualize_training_results(env, agent, episode)
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
model_path = f"models/trading_agent_continuous_{episode}.pt"
|
||||||
|
agent.save(model_path)
|
||||||
|
logger.info(f"Saved continuous model: {model_path}")
|
||||||
|
|
||||||
|
# Plot training results
|
||||||
|
plot_training_results(stats)
|
||||||
|
|
||||||
|
# Save best PnL model
|
||||||
|
if env.episode_pnl > best_pnl:
|
||||||
|
best_pnl = env.episode_pnl
|
||||||
|
agent.save(best_pnl_model_path)
|
||||||
|
logger.info(f"New best PnL model saved: ${env.episode_pnl:.2f}")
|
||||||
|
|
||||||
|
# Save best metrics to resume training if interrupted
|
||||||
|
best_metrics = {
|
||||||
|
'best_pnl': float(best_pnl),
|
||||||
|
'last_episode': episode,
|
||||||
|
'timestamp': datetime.datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
os.makedirs("checkpoints", exist_ok=True)
|
||||||
|
with open("checkpoints/best_metrics.json", 'w') as f:
|
||||||
|
json.dump(best_metrics, f)
|
||||||
|
|
||||||
|
# Update target network
|
||||||
|
agent.update_target_network()
|
||||||
|
|
||||||
|
# Increment episode counter
|
||||||
|
episode += 1
|
||||||
|
|
||||||
|
# Sleep briefly to prevent overwhelming the system
|
||||||
|
# Use shorter sleep for shorter timeframes
|
||||||
|
if timeframe.endswith('s'):
|
||||||
|
await asyncio.sleep(0.1) # Very short sleep for second-based timeframes
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Continuous training stopped by user")
|
||||||
|
# Save final model
|
||||||
|
agent.save("models/trading_agent_continuous_final.pt")
|
||||||
|
# Close TensorBoard writer
|
||||||
|
writer.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in continuous training: {e}")
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
# Save emergency model
|
||||||
|
agent.save(f"models/trading_agent_continuous_emergency_{episode}.pt")
|
||||||
|
# Close TensorBoard writer
|
||||||
|
writer.close()
|
||||||
|
|
||||||
elif args.mode == 'evaluate':
|
elif args.mode == 'evaluate':
|
||||||
# Load the best model
|
# Load the best model
|
||||||
agent.load("models/trading_agent_best_pnl.pt")
|
agent.load("models/trading_agent_best_pnl.pt")
|
||||||
|
304
crypto/gogo2/mexc_trading.py
Normal file
304
crypto/gogo2/mexc_trading.py
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import numpy as np
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from mexc_api.spot import Spot
|
||||||
|
from mexc_api.common.enums import Side, OrderType
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger('mexc_trading')
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
|
||||||
|
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
|
||||||
|
|
||||||
|
class MexcTradingClient:
|
||||||
|
"""Client for executing trades on MEXC exchange using the official API"""
|
||||||
|
|
||||||
|
def __init__(self, api_key=None, api_secret=None, symbol="ETH/USDT", leverage=1):
|
||||||
|
"""Initialize the MEXC trading client"""
|
||||||
|
self.api_key = api_key or MEXC_API_KEY
|
||||||
|
self.api_secret = api_secret or MEXC_SECRET_KEY
|
||||||
|
|
||||||
|
# Ensure API keys are not None
|
||||||
|
if not self.api_key or not self.api_secret:
|
||||||
|
logger.warning("API keys not provided. Using empty strings for public endpoints only.")
|
||||||
|
self.api_key = ""
|
||||||
|
self.api_secret = ""
|
||||||
|
|
||||||
|
self.symbol = symbol
|
||||||
|
self.formatted_symbol = symbol.replace('/', '') # MEXC requires no slash
|
||||||
|
self.leverage = leverage
|
||||||
|
self.client = None
|
||||||
|
self.position = 'flat' # 'flat', 'long', or 'short'
|
||||||
|
self.position_size = 0
|
||||||
|
self.entry_price = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
self.open_orders = []
|
||||||
|
self.order_history = []
|
||||||
|
self.trades = []
|
||||||
|
self.win_count = 0
|
||||||
|
self.loss_count = 0
|
||||||
|
|
||||||
|
# Initialize the MEXC API client
|
||||||
|
self.initialize_client()
|
||||||
|
|
||||||
|
def initialize_client(self):
|
||||||
|
"""Initialize the MEXC API client using the API"""
|
||||||
|
try:
|
||||||
|
self.client = Spot(self.api_key, self.api_secret)
|
||||||
|
# Test connection
|
||||||
|
server_time = self.client.market.server_time()
|
||||||
|
logger.info(f"MEXC API client initialized successfully. Server time: {server_time}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize MEXC API client: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def fetch_account_balance(self):
|
||||||
|
"""Fetch account balance from MEXC API"""
|
||||||
|
try:
|
||||||
|
# Check if we have API keys for private endpoints
|
||||||
|
if not self.api_key or self.api_key == "":
|
||||||
|
logger.warning("No API keys provided. Cannot fetch account balance.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
account_info = self.client.account.account_info()
|
||||||
|
if 'balances' in account_info:
|
||||||
|
# Find USDT balance
|
||||||
|
for asset in account_info['balances']:
|
||||||
|
if asset['asset'] == 'USDT':
|
||||||
|
return float(asset['free'])
|
||||||
|
|
||||||
|
logger.warning("Could not find USDT balance")
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching account balance: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def fetch_open_positions(self):
|
||||||
|
"""Fetch open positions from MEXC API"""
|
||||||
|
try:
|
||||||
|
# Check if we have API keys for private endpoints
|
||||||
|
if not self.api_key or self.api_key == "":
|
||||||
|
logger.warning("No API keys provided. Cannot fetch open positions.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch open orders
|
||||||
|
open_orders = self.client.account.open_orders(self.formatted_symbol)
|
||||||
|
return open_orders
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching open positions: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
|
||||||
|
"""Open a new position using MEXC API"""
|
||||||
|
try:
|
||||||
|
# Check if we have API keys for private endpoints
|
||||||
|
if not self.api_key or self.api_key == "":
|
||||||
|
logger.warning("No API keys provided. Cannot open position.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Calculate quantity based on size and price
|
||||||
|
quantity = size / entry_price
|
||||||
|
# Round quantity to appropriate precision
|
||||||
|
quantity = round(quantity, 4) # Adjust precision as needed for your asset
|
||||||
|
|
||||||
|
# Determine order side
|
||||||
|
side = Side.BUY if position_type == 'long' else Side.SELL
|
||||||
|
|
||||||
|
logger.info(f"Opening {position_type} position: {quantity} {self.symbol} at market price")
|
||||||
|
|
||||||
|
# Place market order
|
||||||
|
order_result = self.client.account.new_order(
|
||||||
|
self.formatted_symbol,
|
||||||
|
side,
|
||||||
|
OrderType.MARKET,
|
||||||
|
str(quantity)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Market order result: {order_result}")
|
||||||
|
|
||||||
|
# Check if order was filled
|
||||||
|
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
|
||||||
|
# Get actual entry price
|
||||||
|
actual_entry_price = float(order_result.get('price', entry_price))
|
||||||
|
|
||||||
|
# Place stop loss order
|
||||||
|
sl_side = Side.SELL if position_type == 'long' else Side.BUY
|
||||||
|
sl_order = self.client.account.new_order(
|
||||||
|
self.formatted_symbol,
|
||||||
|
sl_side,
|
||||||
|
OrderType.STOP_LOSS_LIMIT,
|
||||||
|
str(quantity),
|
||||||
|
price=str(stop_loss),
|
||||||
|
stop_price=str(stop_loss),
|
||||||
|
time_in_force="GTC"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Stop loss order placed: {sl_order}")
|
||||||
|
|
||||||
|
# Place take profit order
|
||||||
|
tp_side = Side.SELL if position_type == 'long' else Side.BUY
|
||||||
|
tp_order = self.client.account.new_order(
|
||||||
|
self.formatted_symbol,
|
||||||
|
tp_side,
|
||||||
|
OrderType.TAKE_PROFIT_LIMIT,
|
||||||
|
str(quantity),
|
||||||
|
price=str(take_profit),
|
||||||
|
stop_price=str(take_profit),
|
||||||
|
time_in_force="GTC"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Take profit order placed: {tp_order}")
|
||||||
|
|
||||||
|
# Update local state
|
||||||
|
self.position = position_type
|
||||||
|
self.position_size = size
|
||||||
|
self.entry_price = actual_entry_price
|
||||||
|
self.stop_loss = stop_loss
|
||||||
|
self.take_profit = take_profit
|
||||||
|
|
||||||
|
# Track orders
|
||||||
|
self.open_orders.extend([sl_order, tp_order])
|
||||||
|
self.order_history.append(order_result)
|
||||||
|
|
||||||
|
logger.info(f"Successfully opened {position_type} position at {actual_entry_price}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to open position: {order_result}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error opening position: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close_position(self, reason="manual_close"):
|
||||||
|
"""Close an existing position"""
|
||||||
|
if self.position == 'flat':
|
||||||
|
logger.info("No position to close")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if we have API keys for private endpoints
|
||||||
|
if not self.api_key or self.api_key == "":
|
||||||
|
logger.warning("No API keys provided. Cannot close position.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# First, cancel any existing stop loss/take profit orders
|
||||||
|
try:
|
||||||
|
self.client.account.cancel_open_orders(self.formatted_symbol)
|
||||||
|
logger.info("Canceled all open orders")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error canceling open orders: {e}")
|
||||||
|
|
||||||
|
# Determine order side (opposite of position)
|
||||||
|
side = Side.SELL if self.position == 'long' else Side.BUY
|
||||||
|
|
||||||
|
# Calculate quantity
|
||||||
|
quantity = self.position_size / self.entry_price
|
||||||
|
# Round quantity to appropriate precision
|
||||||
|
quantity = round(quantity, 4) # Adjust precision as needed
|
||||||
|
|
||||||
|
logger.info(f"Closing {self.position} position: {quantity} {self.symbol} at market price")
|
||||||
|
|
||||||
|
# Execute market order to close position
|
||||||
|
order_result = self.client.account.new_order(
|
||||||
|
self.formatted_symbol,
|
||||||
|
side,
|
||||||
|
OrderType.MARKET,
|
||||||
|
str(quantity)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Close order result: {order_result}")
|
||||||
|
|
||||||
|
# Check if order was filled
|
||||||
|
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
|
||||||
|
# Get actual exit price
|
||||||
|
exit_price = float(order_result.get('price', 0))
|
||||||
|
|
||||||
|
# Calculate PnL
|
||||||
|
if self.position == 'long':
|
||||||
|
pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
|
||||||
|
else: # short
|
||||||
|
pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100
|
||||||
|
|
||||||
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
self.trades.append({
|
||||||
|
'type': self.position,
|
||||||
|
'entry': self.entry_price,
|
||||||
|
'exit': exit_price,
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'reason': reason,
|
||||||
|
'order_id': order_result.get('orderId')
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
# Track order history
|
||||||
|
self.order_history.append(order_result)
|
||||||
|
|
||||||
|
logger.info(f"Closed {self.position} position at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||||
|
|
||||||
|
# Reset position
|
||||||
|
self.position = 'flat'
|
||||||
|
self.entry_price = 0
|
||||||
|
self.position_size = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to close position: {order_result}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing position: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def check_order_status(self, order_id):
|
||||||
|
"""Check the status of a specific order"""
|
||||||
|
try:
|
||||||
|
# Check if we have API keys for private endpoints
|
||||||
|
if not self.api_key or self.api_key == "":
|
||||||
|
logger.warning("No API keys provided. Cannot check order status.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
order_status = self.client.account.query_order(
|
||||||
|
self.formatted_symbol,
|
||||||
|
order_id=order_id
|
||||||
|
)
|
||||||
|
return order_status
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking order status: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_market_price(self):
|
||||||
|
"""Get current market price for the symbol"""
|
||||||
|
try:
|
||||||
|
ticker = self.client.market.ticker_price(self.formatted_symbol)
|
||||||
|
if isinstance(ticker, list) and len(ticker) > 0 and 'price' in ticker[0]:
|
||||||
|
return float(ticker[0]['price'])
|
||||||
|
elif isinstance(ticker, dict) and 'price' in ticker:
|
||||||
|
return float(ticker['price'])
|
||||||
|
else:
|
||||||
|
logger.error(f"Unexpected ticker format: {ticker}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting market price: {e}")
|
||||||
|
return None
|
456
crypto/gogo2/mexcapi-readme.md
Normal file
456
crypto/gogo2/mexcapi-readme.md
Normal file
@ -0,0 +1,456 @@
|
|||||||
|
# mexc-api-sdk
|
||||||
|
|
||||||
|
MEXC Official Market and trade api sdk, easy to connection and send request to MEXC open api !
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- To use our SDK you have to install nodejs LTS (https://aws.github.io/jsii/user-guides/lib-user/)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
1.
|
||||||
|
```
|
||||||
|
git clone https://github.com/mxcdevelop/mexc-api-sdk.git
|
||||||
|
```
|
||||||
|
2. cd dist/{language} and unzip the file
|
||||||
|
3. we offer five language : dotnet, go, java, js, python
|
||||||
|
|
||||||
|
## Table of APIS
|
||||||
|
- [Init](#init)
|
||||||
|
- [Market](#market)
|
||||||
|
- [Ping](#ping)
|
||||||
|
- [Check Server Time](#check-server-time)
|
||||||
|
- [Exchange Information](#exchange-information)
|
||||||
|
- [Recent Trades List](#recent-trades-list)
|
||||||
|
- [Order Book](#order-book)
|
||||||
|
- [Old Trade Lookup](#old-trade-lookup)
|
||||||
|
- [Aggregate Trades List](#aggregate-trades-list)
|
||||||
|
- [kline Data](#kline-data)
|
||||||
|
- [Current Average Price](#current-average-price)
|
||||||
|
- [24hr Ticker Price Change Statistics](#24hr-ticker-price-change-statistics)
|
||||||
|
- [Symbol Price Ticker](#symbol-price-ticker)
|
||||||
|
- [Symbol Order Book Ticker](#symbol-order-book-ticker)
|
||||||
|
- [Trade](#trade)
|
||||||
|
- [Test New Order](#test-new-order)
|
||||||
|
- [New Order](#new-order)
|
||||||
|
- [cancel-order](#cancel-order)
|
||||||
|
- [Cancel all Open Orders on a Symbol](#cancel-all-open-orders-on-a-symbol)
|
||||||
|
- [Query Order](#query-order)
|
||||||
|
- [Current Open Orders](#current-open-orders)
|
||||||
|
- [All Orders](#all-orders)
|
||||||
|
- [Account Information](#account-information)
|
||||||
|
- [Account Trade List](#account-trade-list)
|
||||||
|
## Init
|
||||||
|
```javascript
|
||||||
|
//Javascript
|
||||||
|
import * as Mexc from 'mexc-sdk';
|
||||||
|
const apiKey = 'apiKey'
|
||||||
|
const apiSecret = 'apiSecret'
|
||||||
|
const client = new Mexc.Spot(apiKey, apiSecret);
|
||||||
|
```
|
||||||
|
```go
|
||||||
|
// Go
|
||||||
|
package main
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"mexc-sdk/mexcsdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
apiKey := "apiKey"
|
||||||
|
apiSecret := "apiSecret"
|
||||||
|
spot := mexcsdk.NewSpot(apiKey, apiSecret)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```python
|
||||||
|
# python
|
||||||
|
from mexc_sdk import Spot
|
||||||
|
spot = Spot(api_key='apiKey', api_secret='apiSecret')
|
||||||
|
```
|
||||||
|
```java
|
||||||
|
// java
|
||||||
|
import Mexc.Sdk.*;
|
||||||
|
class MyClass {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
String apiKey= "apiKey";
|
||||||
|
String apiSecret= "apiSecret";
|
||||||
|
Spot mySpot = new Spot(apiKey, apiSecret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```C#
|
||||||
|
// dotnet
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using Mxc.Sdk;
|
||||||
|
|
||||||
|
namespace dotnet
|
||||||
|
{
|
||||||
|
class Program
|
||||||
|
{
|
||||||
|
static void Main(string[] args)
|
||||||
|
{
|
||||||
|
string apiKey = "apiKey";
|
||||||
|
string apiSecret= "apiSecret";
|
||||||
|
var spot = new Spot(apiKey, apiSecret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Market
|
||||||
|
### Ping
|
||||||
|
```javascript
|
||||||
|
client.ping()
|
||||||
|
```
|
||||||
|
### Check Server Time
|
||||||
|
```javascript
|
||||||
|
client.time()
|
||||||
|
```
|
||||||
|
### Exchange Information
|
||||||
|
```javascript
|
||||||
|
client.exchangeInfo(options: any)
|
||||||
|
options:{symbol, symbols}
|
||||||
|
/**
|
||||||
|
* choose one parameter
|
||||||
|
*
|
||||||
|
* symbol :
|
||||||
|
* example "BNBBTC";
|
||||||
|
*
|
||||||
|
* symbols :
|
||||||
|
* array of symbol
|
||||||
|
* example ["BTCUSDT","BNBBTC"];
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Recent Trades List
|
||||||
|
```javascript
|
||||||
|
client.trades(symbol: string, options: any = { limit: 500 })
|
||||||
|
options:{limit}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
### Order Book
|
||||||
|
```javascript
|
||||||
|
client.depth(symbol: string, options: any = { limit: 100 })
|
||||||
|
options:{limit}
|
||||||
|
/**
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 100;
|
||||||
|
* max 5000;
|
||||||
|
* Valid:[5, 10, 20, 50, 100, 500, 1000, 5000]
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Old Trade Lookup
|
||||||
|
```javascript
|
||||||
|
client.historicalTrades(symbol: string, options: any = { limit: 500 })
|
||||||
|
options:{limit, fromId}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
* fromId:
|
||||||
|
* Trade id to fetch from. Default gets most recent trades
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Aggregate Trades List
|
||||||
|
```javascript
|
||||||
|
client.aggTrades(symbol: string, options: any = { limit: 500 })
|
||||||
|
options:{fromId, startTime, endTime, limit}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* fromId :
|
||||||
|
* id to get aggregate trades from INCLUSIVE
|
||||||
|
*
|
||||||
|
* startTime:
|
||||||
|
* start at
|
||||||
|
*
|
||||||
|
* endTime:
|
||||||
|
* end at
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
### kline Data
|
||||||
|
```javascript
|
||||||
|
client.klines(symbol: string, interval: string, options: any = { limit: 500 })
|
||||||
|
options:{ startTime, endTime, limit}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* interval :
|
||||||
|
* m :minute;
|
||||||
|
* h :Hour;
|
||||||
|
* d :day;
|
||||||
|
* w :week;
|
||||||
|
* M :month
|
||||||
|
* example : "1m"
|
||||||
|
*
|
||||||
|
* startTime :
|
||||||
|
* start at
|
||||||
|
*
|
||||||
|
* endTime :
|
||||||
|
* end at
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Current Average Price
|
||||||
|
```javascript
|
||||||
|
client.avgPrice(symbol: string)
|
||||||
|
```
|
||||||
|
### 24hr Ticker Price Change Statistics
|
||||||
|
```javascript
|
||||||
|
client.ticker24hr(symbol?: string)
|
||||||
|
```
|
||||||
|
### Symbol Price Ticker
|
||||||
|
```javascript
|
||||||
|
client.tickerPrice(symbol?: string)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Symbol Order Book Ticker
|
||||||
|
```javascript
|
||||||
|
client.bookTicker(symbol?: string)
|
||||||
|
```
|
||||||
|
## Trade
|
||||||
|
### Test New Order
|
||||||
|
```javascript
|
||||||
|
client.newOrderTest(symbol: string, side: string, orderType: string, options: any = {})
|
||||||
|
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* side:
|
||||||
|
* Order side
|
||||||
|
* ENUM:
|
||||||
|
* BUY
|
||||||
|
* SELL
|
||||||
|
*
|
||||||
|
* orderType:
|
||||||
|
* Order type
|
||||||
|
* ENUM:
|
||||||
|
* LIMIT
|
||||||
|
* MARKET
|
||||||
|
* STOP_LOSS
|
||||||
|
* STOP_LOSS_LIMIT
|
||||||
|
* TAKE_PROFIT
|
||||||
|
* TAKE_PROFIT_LIMIT
|
||||||
|
* LIMIT_MAKER
|
||||||
|
*
|
||||||
|
* timeInForce :
|
||||||
|
* How long an order will be active before expiration.
|
||||||
|
* GTC: Active unless the order is canceled
|
||||||
|
* IOC: Order will try to fill the order as much as it can before the order expires
|
||||||
|
* FOK: Active unless the full order cannot be filled upon execution.
|
||||||
|
*
|
||||||
|
* quantity :
|
||||||
|
* target quantity
|
||||||
|
*
|
||||||
|
* quoteOrderQty :
|
||||||
|
* Specify the total spent or received
|
||||||
|
*
|
||||||
|
* price :
|
||||||
|
* target price
|
||||||
|
*
|
||||||
|
* newClientOrderId :
|
||||||
|
* A unique id among open orders. Automatically generated if not sent
|
||||||
|
*
|
||||||
|
* stopPrice :
|
||||||
|
* sed with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
|
||||||
|
*
|
||||||
|
* icebergQty :
|
||||||
|
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
|
||||||
|
*
|
||||||
|
* newOrderRespType :
|
||||||
|
* Set the response JSON. ACK, RESULT, or FULL;
|
||||||
|
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
|
||||||
|
*
|
||||||
|
* recvWindow :
|
||||||
|
* Delay accept time
|
||||||
|
* The value cannot be greater than 60000
|
||||||
|
* defaults: 5000
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### New Order
|
||||||
|
```javascript
|
||||||
|
client.newOrder(symbol: string, side: string, orderType: string, options: any = {})
|
||||||
|
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* side:
|
||||||
|
* Order side
|
||||||
|
* ENUM:
|
||||||
|
* BUY
|
||||||
|
* SELL
|
||||||
|
*
|
||||||
|
* orderType:
|
||||||
|
* Order type
|
||||||
|
* ENUM:
|
||||||
|
* LIMIT
|
||||||
|
* MARKET
|
||||||
|
* STOP_LOSS
|
||||||
|
* STOP_LOSS_LIMIT
|
||||||
|
* TAKE_PROFIT
|
||||||
|
* TAKE_PROFIT_LIMIT
|
||||||
|
* LIMIT_MAKER
|
||||||
|
*
|
||||||
|
* timeInForce :
|
||||||
|
* How long an order will be active before expiration.
|
||||||
|
* GTC: Active unless the order is canceled
|
||||||
|
* IOC: Order will try to fill the order as much as it can before the order expires
|
||||||
|
* FOK: Active unless the full order cannot be filled upon execution.
|
||||||
|
*
|
||||||
|
* quantity :
|
||||||
|
* target quantity
|
||||||
|
*
|
||||||
|
* quoteOrderQty :
|
||||||
|
* Specify the total spent or received
|
||||||
|
*
|
||||||
|
* price :
|
||||||
|
* target price
|
||||||
|
*
|
||||||
|
* newClientOrderId :
|
||||||
|
* A unique id among open orders. Automatically generated if not sent
|
||||||
|
*
|
||||||
|
* stopPrice :
|
||||||
|
* sed with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
|
||||||
|
*
|
||||||
|
* icebergQty :
|
||||||
|
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
|
||||||
|
*
|
||||||
|
* newOrderRespType :
|
||||||
|
* Set the response JSON. ACK, RESULT, or FULL;
|
||||||
|
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
|
||||||
|
*
|
||||||
|
* recvWindow :
|
||||||
|
* Delay accept time
|
||||||
|
* The value cannot be greater than 60000
|
||||||
|
* defaults: 5000
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### cancel-order
|
||||||
|
```javascript
|
||||||
|
client.cancelOrder(symbol: string, options:any = {})
|
||||||
|
options:{ orderId, origClientOrderId, newClientOrderId}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* Either orderId or origClientOrderId must be sent
|
||||||
|
*
|
||||||
|
* orderId:
|
||||||
|
* target orderId
|
||||||
|
*
|
||||||
|
* origClientOrderId:
|
||||||
|
* target origClientOrderId
|
||||||
|
*
|
||||||
|
* newClientOrderId:
|
||||||
|
* Used to uniquely identify this cancel. Automatically generated by default.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cancel all Open Orders on a Symbol
|
||||||
|
```javascript
|
||||||
|
client.cancelOpenOrders(symbol: string)
|
||||||
|
```
|
||||||
|
### Query Order
|
||||||
|
```javascript
|
||||||
|
client.queryOrder(symbol: string, options:any = {})
|
||||||
|
options:{ orderId, origClientOrderId}
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* Either orderId or origClientOrderId must be sent
|
||||||
|
*
|
||||||
|
* orderId:
|
||||||
|
* target orderId
|
||||||
|
*
|
||||||
|
* origClientOrderId:
|
||||||
|
* target origClientOrderId
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
### Current Open Orders
|
||||||
|
```javascript
|
||||||
|
client.openOrders(symbol: string)
|
||||||
|
```
|
||||||
|
### All Orders
|
||||||
|
```javascript
|
||||||
|
client.allOrders(symbol: string, options: any = { limit: 500 })
|
||||||
|
options:{ orderId, startTime, endTime, limit}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* orderId:
|
||||||
|
* target orderId
|
||||||
|
*
|
||||||
|
* startTime:
|
||||||
|
* start at
|
||||||
|
*
|
||||||
|
* endTime:
|
||||||
|
* end at
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
||||||
|
### Account Information
|
||||||
|
```javascript
|
||||||
|
client.accountInfo()
|
||||||
|
```
|
||||||
|
### Account Trade List
|
||||||
|
```javascript
|
||||||
|
client.accountTradeList(symbol: string, options:any = { limit: 500 })
|
||||||
|
options:{ orderId, startTime, endTime, fromId, limit}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* orderId:
|
||||||
|
* target orderId
|
||||||
|
*
|
||||||
|
* startTime:
|
||||||
|
* start at
|
||||||
|
*
|
||||||
|
* endTime:
|
||||||
|
* end at
|
||||||
|
*
|
||||||
|
* fromId:
|
||||||
|
* TradeId to fetch from. Default gets most recent trades
|
||||||
|
*
|
||||||
|
* limit :
|
||||||
|
* Number of returned data
|
||||||
|
* Default 500;
|
||||||
|
* max 1000;
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
```
|
@ -1,188 +1,91 @@
|
|||||||
# Crypto Trading Bot with Reinforcement Learning
|
# Crypto Trading Bot with MEXC API Integration
|
||||||
|
|
||||||
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition.
|
This is an AI-powered cryptocurrency trading bot that can run in both simulation (demo) mode and live trading mode using the MEXC exchange API.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Deep Q-Learning with experience replay
|
- Deep Reinforcement Learning agent for trading decisions
|
||||||
- LSTM layers for sequential data processing
|
- Technical indicators and price prediction
|
||||||
- Multi-head attention mechanism
|
- Live trading integration with MEXC exchange via mexc-api
|
||||||
- Dueling DQN architecture
|
- Demo mode for testing without real trades
|
||||||
- Real-time trading capabilities
|
- Real-time data streaming via websockets
|
||||||
- TensorBoard integration for monitoring
|
- Performance tracking and visualization
|
||||||
- Comprehensive technical indicators
|
|
||||||
- Demo and live trading modes
|
|
||||||
- Automatic model checkpointing
|
|
||||||
|
|
||||||
## Prerequisites
|
## Setup
|
||||||
|
|
||||||
- Python 3.8+
|
1. Clone the repository
|
||||||
- MEXC Exchange API credentials
|
2. Install dependencies:
|
||||||
- GPU recommended but not required
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
1. Clone the repository:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/yourusername/crypto-trading-bot.git
|
|
||||||
cd crypto-trading-bot
|
|
||||||
```
|
```
|
||||||
2. Create a virtual environment:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m venv venv
|
|
||||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
|
||||||
```
|
|
||||||
3. Install dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
3. Create a `.env` file in the root directory with your MEXC API keys:
|
||||||
|
|
||||||
4. Create a `.env` file in the project root with your MEXC API credentials:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
MEXC_API_KEY=your_api_key
|
|
||||||
MEXC_API_SECRET=your_api_secret
|
|
||||||
|
|
||||||
|
|
||||||
cuda support
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
|
||||||
```
|
```
|
||||||
|
MEXC_API_KEY=your_api_key_here
|
||||||
|
MEXC_SECRET_KEY=your_secret_key_here
|
||||||
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
The bot can be run in three modes:
|
The bot can be run in three different modes:
|
||||||
|
|
||||||
### Training Mode
|
### Training Mode
|
||||||
|
|
||||||
```bash
|
Train the agent on historical data:
|
||||||
python main.py --mode train --episodes 1000
|
|
||||||
|
```
|
||||||
|
python main.py --mode train --episodes 100
|
||||||
```
|
```
|
||||||
|
|
||||||
### Evaluation Mode
|
### Evaluation Mode
|
||||||
|
|
||||||
```bash
|
Evaluate the trained agent on historical data:
|
||||||
python main.py --mode eval --episodes 10
|
|
||||||
|
```
|
||||||
|
python main.py --mode evaluate
|
||||||
```
|
```
|
||||||
|
|
||||||
### Live Trading Mode
|
### Live Trading Mode
|
||||||
|
|
||||||
```bash
|
Run the bot in live trading mode:
|
||||||
# Demo mode (simulated trading with real market data)
|
|
||||||
python main.py --mode live --demo
|
|
||||||
|
|
||||||
# Real trading (actual trades on MEXC)
|
```
|
||||||
python main.py --mode live
|
python main.py --mode live
|
||||||
```
|
```
|
||||||
|
|
||||||
Demo mode simulates trading using real-time market data but does not execute actual trades. It still:
|
To run in demo mode (no real trades):
|
||||||
- Logs all trading decisions and performance metrics
|
|
||||||
- Updates the model based on market data (if in training mode)
|
|
||||||
- Displays real-time analytics and position information
|
|
||||||
- Calculates theoretical profits/losses
|
|
||||||
- Saves performance data to TensorBoard
|
|
||||||
|
|
||||||
This makes it perfect for testing strategies without financial risk.
|
```
|
||||||
|
python main.py --mode live --demo
|
||||||
|
```
|
||||||
|
|
||||||
|
## Live Trading Implementation
|
||||||
|
|
||||||
|
The bot uses the mexc-api package to execute trades on the MEXC exchange. The implementation includes:
|
||||||
|
|
||||||
|
- Market order execution for opening and closing positions
|
||||||
|
- Stop loss and take profit orders
|
||||||
|
- Real-time balance updates
|
||||||
|
- Trade history tracking
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
Key parameters can be adjusted in `main.py`:
|
You can adjust the following parameters in `main.py`:
|
||||||
|
|
||||||
- `INITIAL_BALANCE`: Starting balance for training/demo
|
- `INITIAL_BALANCE`: Starting balance for simulation
|
||||||
- `MAX_LEVERAGE`: Maximum leverage for trades
|
- `MAX_LEVERAGE`: Leverage to use for trading
|
||||||
- `STOP_LOSS_PERCENT`: Stop loss percentage
|
- `STOP_LOSS_PERCENT`: Default stop loss percentage
|
||||||
- `TAKE_PROFIT_PERCENT`: Take profit percentage
|
- `TAKE_PROFIT_PERCENT`: Default take profit percentage
|
||||||
- `BATCH_SIZE`: Training batch size
|
|
||||||
- `LEARNING_RATE`: Model learning rate
|
|
||||||
- `STATE_SIZE`: Size of the state representation
|
|
||||||
|
|
||||||
## Model Architecture
|
## Architecture
|
||||||
|
|
||||||
The DQN model includes:
|
|
||||||
- Input layer with technical indicators
|
|
||||||
- LSTM layers for temporal pattern recognition
|
|
||||||
- Multi-head attention mechanism
|
|
||||||
- Dueling architecture for better Q-value estimation
|
|
||||||
- Batch normalization for stable training
|
|
||||||
|
|
||||||
## Monitoring
|
|
||||||
|
|
||||||
Training progress can be monitored using TensorBoard:
|
|
||||||
|
|
||||||
|
|
||||||
Training progress is logged to TensorBoard:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
tensorboard --logdir=logs
|
|
||||||
```
|
|
||||||
|
|
||||||
This will show:
|
|
||||||
- Training rewards
|
|
||||||
- Account balance
|
|
||||||
- Win rate
|
|
||||||
- Loss metrics
|
|
||||||
|
|
||||||
## Trading Strategy
|
|
||||||
|
|
||||||
The bot makes decisions based on:
|
|
||||||
- Price action
|
|
||||||
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
|
||||||
- Historical patterns through LSTM
|
|
||||||
- Risk management with stop-loss and take-profit
|
|
||||||
|
|
||||||
## Safety Features
|
|
||||||
|
|
||||||
- Demo mode for safe testing
|
|
||||||
- Automatic stop-loss
|
|
||||||
- Position size limits
|
|
||||||
- Error handling for API calls
|
|
||||||
- Logging of all actions
|
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
├── main.py # Main bot implementation
|
|
||||||
├── requirements.txt # Project dependencies
|
|
||||||
├── .env # API credentials
|
|
||||||
├── models/ # Saved model checkpoints
|
|
||||||
├── runs/ # TensorBoard logs
|
|
||||||
└── trading_bot.log # Activity logs
|
|
||||||
|
|
||||||
|
- `main.py`: Main entry point and trading logic
|
||||||
|
- `mexc_trading.py`: MEXC API integration for live trading using mexc-api
|
||||||
|
- `models/`: Directory for saved model weights
|
||||||
|
|
||||||
## Warning
|
## Warning
|
||||||
|
|
||||||
Cryptocurrency trading carries significant risks. This bot is for educational purposes and should not be used with real money without thorough testing and understanding of the risks involved.
|
Trading cryptocurrencies involves significant risk. This bot is provided for educational purposes only. Use at your own risk.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
[MIT License](LICENSE)
|
MIT
|
||||||
|
|
||||||
The main changes I made:
|
|
||||||
Fixed code block formatting by adding proper language identifiers
|
|
||||||
Added missing closing code blocks
|
|
||||||
Properly formatted directory structure
|
|
||||||
Added complete sections that were cut off in the original
|
|
||||||
Ensured consistent formatting throughout the document
|
|
||||||
Added proper bash syntax highlighting for command examples
|
|
||||||
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Edits/improvements
|
|
||||||
|
|
||||||
Fixes the shape mismatch by ensuring the state vector is exactly STATE_SIZE elements
|
|
||||||
Adds robust error handling in the model's forward pass to handle mismatched inputs
|
|
||||||
Adds a transformer encoder for more sophisticated pattern recognition
|
|
||||||
Provides an expand_model method to increase model capacity while preserving learned weights
|
|
||||||
Adds detailed logging about model size and shape mismatches
|
|
||||||
The model now has:
|
|
||||||
Configurable hidden layer sizes
|
|
||||||
Transformer layers for complex pattern recognition
|
|
||||||
LSTM layers for temporal patterns
|
|
||||||
Attention mechanisms for focusing on important features
|
|
||||||
Dueling architecture for better Q-value estimation
|
|
||||||
With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need to implement a more distributed architecture with multiple GPUs, which would require significant changes to the training loop.
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
numpy>=1.21.0
|
numpy>=1.20.0
|
||||||
pandas>=1.3.0
|
pandas>=1.3.0
|
||||||
matplotlib>=3.4.0
|
matplotlib>=3.4.0
|
||||||
torch>=1.9.0
|
torch>=1.9.0
|
||||||
python-dotenv>=0.19.0
|
scikit-learn>=0.24.0
|
||||||
ccxt>=2.0.0
|
ccxt>=2.0.0
|
||||||
|
python-dotenv>=0.19.0
|
||||||
websockets>=10.0
|
websockets>=10.0
|
||||||
tensorboard>=2.6.0
|
tensorboard>=2.7.0
|
||||||
scikit-learn
|
mexc-api>=1.0.0
|
||||||
mplfinance
|
asyncio>=3.4.3
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
crypto/gogo2/test_mexc_api.py
Normal file
23
crypto/gogo2/test_mexc_api.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from mexc_api.spot import Spot
|
||||||
|
|
||||||
|
def test_mexc_api():
|
||||||
|
try:
|
||||||
|
# Initialize client with empty API keys for public endpoints
|
||||||
|
client = Spot("", "")
|
||||||
|
|
||||||
|
# Test server time endpoint
|
||||||
|
server_time = client.market.server_time()
|
||||||
|
print(f"Server time: {server_time}")
|
||||||
|
|
||||||
|
# Test ticker price endpoint
|
||||||
|
ticker = client.market.ticker_price("ETHUSDT")
|
||||||
|
print(f"ETH/USDT price: {ticker}")
|
||||||
|
|
||||||
|
print("MEXC API is working correctly!")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error testing MEXC API: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_mexc_api()
|
34
crypto/gogo2/test_trading_client.py
Normal file
34
crypto/gogo2/test_trading_client.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from mexc_trading import MexcTradingClient
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
async def test_trading_client():
|
||||||
|
"""Test the MexcTradingClient functionality"""
|
||||||
|
print("Initializing MexcTradingClient...")
|
||||||
|
client = MexcTradingClient(symbol="ETH/USDT")
|
||||||
|
|
||||||
|
# Test getting market price
|
||||||
|
print("Testing get_market_price...")
|
||||||
|
price = await client.get_market_price()
|
||||||
|
print(f"Current ETH/USDT price: {price}")
|
||||||
|
|
||||||
|
# If API keys are provided, test account balance
|
||||||
|
if os.getenv('MEXC_API_KEY') and os.getenv('MEXC_SECRET_KEY'):
|
||||||
|
print("Testing fetch_account_balance...")
|
||||||
|
balance = await client.fetch_account_balance()
|
||||||
|
print(f"Account balance: {balance} USDT")
|
||||||
|
|
||||||
|
print("Testing fetch_open_positions...")
|
||||||
|
positions = await client.fetch_open_positions()
|
||||||
|
print(f"Open positions: {positions}")
|
||||||
|
else:
|
||||||
|
print("No API keys provided. Skipping private endpoint tests.")
|
||||||
|
|
||||||
|
print("MexcTradingClient test completed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_trading_client())
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Before Width: | Height: | Size: 170 KiB After Width: | Height: | Size: 307 KiB |
Binary file not shown.
Before Width: | Height: | Size: 86 KiB After Width: | Height: | Size: 84 KiB |
Loading…
x
Reference in New Issue
Block a user