fixes, lots of new ideas
parent d9f1bac11c
commit ad559d8c61
.gitignore (vendored), 4 lines changed
@@ -38,4 +38,6 @@ crypto/gogo2/trading_bot.log
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
*trading_agent_continuous_*.pt
*trading_agent_episode_*.pt
crypto/gogo2/models/trading_agent_continuous_150.pt
crypto/gogo2/models/trading_agent_continuous_*.pt
crypto/gogo2/visualizations/training_episode_*.png
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
crypto/gogo2/_model.md (Normal file), 67 lines added
@@ -0,0 +1,67 @@
# Neural Network Architecture Analysis for Trading Bot

## Overview

This document provides a comprehensive analysis of the neural network architecture used in our trading bot system. The system consists of two main neural network components:

1. **Price Prediction Model** - Forecasts future price movements and extrema points
2. **DQN (Deep Q-Network)** - Makes trading decisions based on state representations

## 1. Price Prediction Model

### Architecture

```
PricePredictionModel(nn.Module)
├── Input Layer: [batch_size, seq_len, 2] (price, volume)
├── LSTM Layers: 2 stacked layers with hidden_size=128
├── Attention Mechanism: Self-attention with linear projections
├── Linear Layer 1: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 2: hidden_size → output_size (5 future prices)
└── Output: [batch_size, output_size]
```

### Data Flow

**Inputs:**
- `price_history`: Sequence of historical prices [batch_size, seq_len]
- `volume_history`: Sequence of historical volumes [batch_size, seq_len]

**Preprocessing:**
- Normalization using MinMaxScaler (0-1 range)
- Reshaping to [batch_size, seq_len, 2] (price and volume features)

**Forward Pass:**
1. Input data passes through LSTM layers
2. Self-attention mechanism applied to LSTM outputs
3. Linear layers process the attended features
4. Output represents predicted prices for next 5 candles

**Outputs:**
- `predicted_prices`: Array of 5 future price predictions
- `predicted_extrema`: Binary indicators for potential price extrema points
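
The preprocessing step described above (min-max scaling followed by stacking price and volume into `[batch_size, seq_len, 2]`) can be sketched in plain Python. This is an illustration only, not code from the commit: `min_max_scale` and `build_model_input` are hypothetical helper names, and the scaling is written by hand rather than with scikit-learn's `MinMaxScaler` so the example has no dependencies.

```python
from typing import List, Sequence


def min_max_scale(values: Sequence[float]) -> List[float]:
    """Scale values linearly into [0, 1]; a constant series maps to all zeros."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


def build_model_input(price_history: Sequence[float],
                      volume_history: Sequence[float]) -> List[List[float]]:
    """Return one sample shaped [seq_len, 2] with (price, volume) per time step."""
    if len(price_history) != len(volume_history):
        raise ValueError("price and volume histories must have the same length")
    prices = min_max_scale(price_history)
    volumes = min_max_scale(volume_history)
    return [[p, v] for p, v in zip(prices, volumes)]


# A batch is a list of samples, giving the [batch_size, seq_len, 2] layout.
batch = [build_model_input([100.0, 102.0, 101.0, 104.0],
                           [10.0, 30.0, 20.0, 10.0])]
```

Each feature is scaled independently here, so price and volume both end up in the 0-1 range regardless of their original units.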

## 2. DQN (Deep Q-Network)

### Architecture

```
DQN(nn.Module)
├── Input Layer: [batch_size, state_size]
├── Linear Layer 1: state_size → hidden_size (384)
├── ReLU Activation
├── LSTM Layers: 2 stacked layers with hidden_size=384
├── Multi-Head Attention: 4 attention heads
├── Linear Layer 2: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 3: hidden_size → action_size (4)
└── Output: [batch_size, action_size] (Q-values for each action)
```

### Data Flow

**Inputs:**
- `state`: Current market state representation [batch_size, state_size]
  - Price features (normalized prices, returns, volatility)
  - Technical indicators (RSI, MACD, Stochastic,
@@ -4,6 +4,12 @@ ensure we use GPU if available to train faster. during training we need to have

Do we call the trading API properly when in live mode? In live mode, also validate the current balance. Also ensure trades are actually executed after placing the orders by checking the open orders.

Our trading data chart (in TensorBoard) is not displayed properly: the candles seem to be displayed multiple times but shifted in time. We also do not correctly show the buy/sell events on the time axis, and we do not show the predicted price on the chart.

2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
@@ -21,4 +27,14 @@ Backend tkagg is interactive backend. Turning interactive mode on.

2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%

----------------

Do we train the price prediction model on old, known candles, masking the latest ones so that it has to guess the next candle, and then backpropagate against the already known next candle, like a transformer (GPT-2) would do? Or do we use RL for that as well?
It seems the model is not learning much. The balance keeps hovering around the starting value even after some time training in continuous mode.
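
The supervised option raised in the question above (show the model known candles, hide the following ones, and train against the already known values) amounts to slicing the history into input/target pairs. The sketch below only illustrates that slicing and is not taken from the commit; `make_windows`, `seq_len` and `horizon` are hypothetical names, and the training loop itself is left out.

```python
from typing import List, Sequence, Tuple


def make_windows(prices: Sequence[float], seq_len: int,
                 horizon: int) -> List[Tuple[List[float], List[float]]]:
    """Build (input window, future target) pairs from a known price series.

    Each input holds `seq_len` consecutive prices; the target holds the
    `horizon` prices that immediately follow, which the model never sees
    as input for that sample.
    """
    samples = []
    last_start = len(prices) - seq_len - horizon
    for start in range(last_start + 1):
        window = list(prices[start:start + seq_len])
        target = list(prices[start + seq_len:start + seq_len + horizon])
        samples.append((window, target))
    return samples


prices = [float(p) for p in range(10)]
pairs = make_windows(prices, seq_len=4, horizon=2)
```

With `horizon=5` the targets would line up with the five future prices the price prediction model outputs, so a regression loss such as mean squared error could be backpropagated directly, without involving the RL loop.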

It seems we may need another NN model further down the loop just to predict the price extrema.
We may have to include a mechanism that calculates the price extrema retrospectively and use it to bootstrap (pre-train) the model.
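
Calculating extrema retrospectively, as suggested in the notes above, could look like the following sketch. It is one possible labelling rule and an assumption of this example, not something the commit implements: a candle counts as a local high (or low) when its price is strictly above (or below) every neighbour within `window` candles on both sides. `label_extrema` and `window` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def label_extrema(prices: Sequence[float],
                  window: int = 2) -> Tuple[List[int], List[int]]:
    """Return (high_indices, low_indices) found by looking both ways in time.

    The rule needs `window` future candles, so it can only label history
    and cannot be used as a live signal.
    """
    values = list(prices)
    highs, lows = [], []
    for i in range(window, len(values) - window):
        neighbours = values[i - window:i] + values[i + 1:i + window + 1]
        if all(values[i] > n for n in neighbours):
            highs.append(i)
        elif all(values[i] < n for n in neighbours):
            lows.append(i)
    return highs, lows


history = [1.0, 2.0, 5.0, 3.0, 2.0, 0.5, 1.5, 2.5, 4.0, 3.5, 3.0]
high_idx, low_idx = label_extrema(history, window=2)
```

The resulting indices could serve as classification targets for a pre-training pass before the RL loop starts.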
crypto/gogo2/archive.py (Normal file), 362 lines added
@@ -0,0 +1,362 @@
class MexcTradingClient:
    def __init__(self, api_key, secret_key, symbol="ETH/USDT", leverage=50):
        self.client = ccxt.mexc({
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True,
        })
        self.symbol = symbol
        self.leverage = leverage
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.trades = []

    def initialize_mexc_client(self, api_key, api_secret):
        """Initialize the MEXC API client"""
        try:
            from mexc_sdk import Spot
            self.mexc_client = Spot(api_key=api_key, api_secret=api_secret)
            # Test connection
            self.mexc_client.ping()
            logger.info("MEXC API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize MEXC API client: {e}")
            logger.error(traceback.format_exc())
            return False

    async def fetch_account_balance(self):
        """Fetch actual account balance from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, use simulated balance
            return self.balance

        try:
            account_info = self.mexc_client.accountInfo()
            if 'balances' in account_info:
                # Find USDT balance
                for asset in account_info['balances']:
                    if asset['asset'] == 'USDT':
                        return float(asset['free'])

            logger.warning("Could not find USDT balance, using current simulated balance")
            return self.balance
        except Exception as e:
            logger.error(f"Error fetching account balance: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated balance in case of API error
            return self.balance

    async def fetch_open_positions(self):
        """Fetch actual open positions from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, return current simulated position
            return [{
                'symbol': 'ETH/USDT',
                'positionSide': 'LONG' if self.position == 'long' else 'SHORT' if self.position == 'short' else 'NONE',
                'positionAmt': self.position_size / self.current_price if self.position != 'flat' else 0,
                'entryPrice': self.entry_price,
                'unrealizedProfit': self.calculate_unrealized_pnl()
            }] if self.position != 'flat' else []

        try:
            # Fetch open positions
            positions = self.mexc_client.openOrders('ETH/USDT')
            return positions
        except Exception as e:
            logger.error(f"Error fetching open positions: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated positions in case of API error
            return []

    def calculate_unrealized_pnl(self):
        """Calculate unrealized PnL for the current position"""
        if self.position == 'flat':
            return 0.0

        position_value = self.position_size / self.entry_price

        if self.position == 'long':
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
        else:  # short
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100

        # Apply leverage
        pnl_percent *= self.leverage

        return position_value * pnl_percent / 100

    async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
        """Open a new position using MEXC API in live mode, or simulate in demo mode"""
        if self.demo or not self.mexc_client:
            # In demo mode, simulate opening a position
            self.position = position_type
            self.position_size = size
            self.entry_price = entry_price
            self.entry_index = self.current_step
            self.stop_loss = stop_loss
            self.take_profit = take_profit

            logger.info(f"DEMO: Opened {position_type.upper()} position at {entry_price} | " +
                        f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
            return True

        try:
            # In live mode, place actual orders via API
            symbol = "ETHUSDT"  # Format required by MEXC
            side = "BUY" if position_type == 'long' else "SELL"

            # Calculate quantity based on size and price
            quantity = size / entry_price

            # Place main order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "leverage": self.leverage,
                    "newOrderRespType": "FULL"
                }
            )

            # Check if order executed
            if order_result.get('status') == 'FILLED':
                actual_entry_price = float(order_result.get('price', entry_price))

                # Place stop loss order
                sl_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="STOP_LOSS",
                    quantity=quantity,
                    options={
                        "stopPrice": stop_loss,
                        "newOrderRespType": "ACK"
                    }
                )

                # Place take profit order
                tp_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="TAKE_PROFIT",
                    quantity=quantity,
                    options={
                        "stopPrice": take_profit,
                        "newOrderRespType": "ACK"
                    }
                )

                # Update local state
                self.position = position_type
                self.position_size = size
                self.entry_price = actual_entry_price
                self.entry_index = self.current_step
                self.stop_loss = stop_loss
                self.take_profit = take_profit

                # Track orders
                self.open_orders.extend([sl_order, tp_order])
                self.order_history.append(order_result)

                logger.info(f"LIVE: Opened {position_type.upper()} position at {actual_entry_price} | " +
                            f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
                return True

            else:
                logger.error(f"Failed to execute order: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error opening position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def close_position(self, reason="manual_close"):
        """Close the current position using MEXC API in live mode, or simulate in demo mode"""
        if self.position == 'flat':
            logger.info("No position to close")
            return False

        if self.demo or not self.mexc_client:
            # In demo mode, simulate closing a position
            position_type = self.position
            entry_price = self.entry_price
            exit_price = self.current_price
            position_size = self.position_size

            # Calculate PnL
            if position_type == 'long':
                pnl_percent = (exit_price - entry_price) / entry_price * 100
            else:  # short
                pnl_percent = (entry_price - exit_price) / entry_price * 100

            # Apply leverage
            pnl_percent *= self.leverage

            # Calculate actual PnL
            pnl_dollar = position_size * pnl_percent / 100

            # Apply fees
            pnl_dollar -= self.calculate_fees(position_size)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar
            self.episode_pnl += pnl_dollar

            # Update max drawdown
            if self.balance > self.peak_balance:
                self.peak_balance = self.balance
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)

            # Record trade
            self.trades.append({
                'type': position_type,
                'entry': entry_price,
                'exit': exit_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,
                'market_direction': self.get_market_direction(),
                'reason': reason,
                'leverage': self.leverage
            })

            # Update win/loss count
            if pnl_dollar > 0:
                self.win_count += 1
            else:
                self.loss_count += 1

            logger.info(f"DEMO: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Reset position
            self.position = 'flat'
            self.entry_price = 0
            self.entry_index = 0
            self.position_size = 0
            self.stop_loss = 0
            self.take_profit = 0

            return True

        try:
            # In live mode, close position via API
            symbol = "ETHUSDT"
            position_info = await self.fetch_open_positions()

            if not position_info:
                logger.warning("No open positions found to close")
                self.position = 'flat'
                return False

            # First, cancel any existing stop loss/take profit orders
            try:
                self.mexc_client.cancelOpenOrders(symbol)
            except Exception as e:
                logger.warning(f"Error canceling open orders: {e}")

            # Close the position with a market order
            position_type = self.position
            side = "SELL" if position_type == 'long' else "BUY"
            quantity = self.position_size / self.current_price

            # Execute order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "newOrderRespType": "FULL"
                }
            )

            # Check if order executed
            if order_result.get('status') == 'FILLED':
                exit_price = float(order_result.get('price', self.current_price))
                entry_price = self.entry_price
                position_size = self.position_size

                # Calculate PnL
                if position_type == 'long':
                    pnl_percent = (exit_price - entry_price) / entry_price * 100
                else:  # short
                    pnl_percent = (entry_price - exit_price) / entry_price * 100

                # Apply leverage
                pnl_percent *= self.leverage

                # Calculate actual PnL
                pnl_dollar = position_size * pnl_percent / 100

                # Apply fees
                pnl_dollar -= self.calculate_fees(position_size)

                # Update balance from API
                self.balance = await self.fetch_account_balance()
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': position_type,
                    'entry': entry_price,
                    'exit': exit_price,
                    'entry_time': self.data[self.entry_index]['timestamp'],
                    'exit_time': self.data[self.current_step]['timestamp'],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': reason,
                    'leverage': self.leverage,
                    'order_id': order_result.get('orderId')
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

                # Track order history
                self.order_history.append(order_result)

                logger.info(f"LIVE: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

                return True
            else:
                logger.error(f"Failed to close position: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error closing position: {e}")
            logger.error(traceback.format_exc())
            return False

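
The PnL and drawdown bookkeeping that `close_position` repeats in its demo and live branches can be isolated into two small pure functions. The sketch below is illustrative and not part of the commit: `leveraged_pnl` and `update_drawdown` are hypothetical names, and fees are left out because `calculate_fees` is not shown in this diff. It follows the `close_position` formula (position size in dollars times the leveraged percentage); `calculate_unrealized_pnl` above divides the size by the entry price first, so its result is not in the same units.

```python
from typing import Tuple


def leveraged_pnl(position: str, size_usd: float, entry_price: float,
                  exit_price: float, leverage: float) -> Tuple[float, float]:
    """Return (pnl_percent, pnl_dollar) for a closed long or short position."""
    if position == 'long':
        pnl_percent = (exit_price - entry_price) / entry_price * 100
    elif position == 'short':
        pnl_percent = (entry_price - exit_price) / entry_price * 100
    else:
        raise ValueError(f"unknown position type: {position!r}")
    pnl_percent *= leverage  # leverage scales the percentage move
    return pnl_percent, size_usd * pnl_percent / 100


def update_drawdown(balance: float, peak_balance: float,
                    max_drawdown: float) -> Tuple[float, float]:
    """Return the updated (peak_balance, max_drawdown) after a balance change."""
    peak_balance = max(peak_balance, balance)
    drawdown = (peak_balance - balance) / peak_balance
    return peak_balance, max(max_drawdown, drawdown)


long_pct, long_usd = leveraged_pnl('long', 100.0, 2000.0, 2010.0, 50)
short_pct, short_usd = leveraged_pnl('short', 100.0, 2000.0, 2010.0, 50)
peak, worst = update_drawdown(90.0, 100.0, 0.05)
```

In this example a 0.5% move at 50x leverage becomes a 25% gain or loss on the $100 position, which shows how quickly the default leverage amplifies small price changes.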
@@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}
+{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 0, "timestamp": "2025-03-12T00:23:19.125190"}
crypto/gogo2/count_params.py (Normal file), 207 lines added
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class PricePredictionModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=128, num_layers=2, output_dim=5):
        super(PricePredictionModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        # Self-attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)

        # Fully connected layer for price prediction
        self.price_fc = nn.Linear(hidden_dim, output_dim)

        # Fully connected layer for extrema prediction (high and low points)
        self.extrema_fc = nn.Linear(hidden_dim, 10)  # 5 time steps, 2 classes (high/low) each

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_dim)

        # LSTM forward pass
        lstm_out, _ = self.lstm(x)  # lstm_out: (batch_size, seq_len, hidden_dim)

        # Self-attention
        attn_output, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Price prediction
        price_pred = self.price_fc(attn_output[:, -1, :])  # Use the last time step

        # Extrema prediction
        extrema_logits = self.extrema_fc(attn_output[:, -1, :])

        return price_pred, extrema_logits

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(DQN, self).__init__()

        # Feature extraction layers
        self.feature_extraction = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Transformer for temporal dependencies
        encoder_layers = TransformerEncoderLayer(d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim*4, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)

        # LSTM for sequential decision making
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)

        # Final layers
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state, hidden=None):
        # Extract features
        features = self.feature_extraction(state)
        features = features.unsqueeze(1)  # Add sequence dimension for transformer/LSTM

        # Transformer processing
        transformer_out = self.transformer(features)

        # LSTM processing
        lstm_out, lstm_hidden = self.lstm(transformer_out)

        # Dueling architecture
        advantage = self.advantage_stream(features.squeeze(1))
        value = self.value_stream(features.squeeze(1))

        # Combine transformer, LSTM and dueling outputs
        combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
        q_values = self.final_layers(combined)

        # Dueling Q-value computation
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        return q_values, lstm_hidden

def count_parameters(model):
    total_params = 0
    layer_params = {}

    for name, param in model.named_parameters():
        if param.requires_grad:
            param_count = param.numel()
            total_params += param_count
            layer_params[name] = (param_count, param.shape)

    return total_params, layer_params

def main():
    # Initialize the Price Prediction Model
    price_model = PricePredictionModel()
    price_total_params, price_layer_params = count_parameters(price_model)

    print(f"Price Prediction Model parameters: {price_total_params:,}")
    print("\nPrice Prediction Model Layers:")
    for name, (count, shape) in price_layer_params.items():
        print(f"{name}: {count:,} (shape: {shape})")

    # Initialize the DQN Model with typical dimensions
    state_dim = 50  # Typical state dimension for the trading bot
    action_dim = 3  # Typical action dimension (buy, sell, hold)
    dqn_model = DQN(state_dim=state_dim, action_dim=action_dim)
    dqn_total_params, dqn_layer_params = count_parameters(dqn_model)

    # Count parameters by category
    feature_extraction_params = sum(count for name, (count, _) in dqn_layer_params.items() if "feature_extraction" in name)
    advantage_value_params = sum(count for name, (count, _) in dqn_layer_params.items() if "advantage_stream" in name or "value_stream" in name)
    transformer_params = sum(count for name, (count, _) in dqn_layer_params.items() if "transformer" in name)
    lstm_params = sum(count for name, (count, _) in dqn_layer_params.items() if "lstm" in name and "transformer" not in name)
    final_layers_params = sum(count for name, (count, _) in dqn_layer_params.items() if "final_layers" in name)

    print(f"\nDQN Model parameters: {dqn_total_params:,}")

    # Create sets to track which parameters we've printed
    printed_params = set()

    # Print DQN layers in groups to avoid output truncation
    print(f"\nDQN Model Layers (Feature Extraction): {feature_extraction_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "feature_extraction" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)

    print(f"\nDQN Model Layers (Advantage & Value Streams): {advantage_value_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "advantage_stream" in name or "value_stream" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)

    print(f"\nDQN Model Layers (Transformer): {transformer_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "transformer" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)

    print(f"\nDQN Model Layers (LSTM): {lstm_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "lstm" in name and "transformer" not in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)

    print(f"\nDQN Model Layers (Final Layers): {final_layers_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "final_layers" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)

    # Print any remaining parameters that weren't caught by the categories above
    remaining_params = set(dqn_layer_params.keys()) - printed_params
    if remaining_params:
        remaining_params_count = sum(dqn_layer_params[name][0] for name in remaining_params)
        print(f"\nDQN Model Layers (Other): {remaining_params_count:,} parameters")
        for name in remaining_params:
            count, shape = dqn_layer_params[name]
            print(f"{name}: {count:,} (shape: {shape})")

    # Total parameters across both models
    print(f"\nTotal parameters (both models): {price_total_params + dqn_total_params:,}")

    # Print summary of parameter distribution
    print("\nParameter Distribution Summary:")
    print(f"Price Prediction Model: {price_total_params:,} parameters ({price_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
    print(f"DQN Model: {dqn_total_params:,} parameters ({dqn_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
    print("\nDQN Model Breakdown:")
    print(f"- Feature Extraction: {feature_extraction_params:,} parameters ({feature_extraction_params/dqn_total_params*100:.1f}%)")
    print(f"- Advantage & Value Streams: {advantage_value_params:,} parameters ({advantage_value_params/dqn_total_params*100:.1f}%)")
    print(f"- Transformer: {transformer_params:,} parameters ({transformer_params/dqn_total_params*100:.1f}%)")
    print(f"- LSTM: {lstm_params:,} parameters ({lstm_params/dqn_total_params*100:.1f}%)")
    print(f"- Final Layers: {final_layers_params:,} parameters ({final_layers_params/dqn_total_params*100:.1f}%)")

    # Verify that all parameters are accounted for
    total_by_category = feature_extraction_params + advantage_value_params + transformer_params + lstm_params + final_layers_params
    if remaining_params:
        total_by_category += remaining_params_count
    print(f"\nSum of all categories: {total_by_category:,} parameters")
    print(f"Difference from total: {dqn_total_params - total_by_category:,} parameters")

if __name__ == "__main__":
    main()
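
The totals that `count_params.py` prints for the LSTM and linear layers can be cross-checked by hand without PyTorch. The sketch below is an addition for illustration, with hypothetical helper names (`lstm_param_count`, `linear_param_count`). It relies on how a unidirectional `nn.LSTM` without projection stores its parameters: per layer, an input weight of shape `(4H, I)`, a recurrent weight of shape `(4H, H)` and two bias vectors of length `4H`.

```python
def lstm_param_count(input_dim: int, hidden_dim: int, num_layers: int) -> int:
    """Parameter count of a stacked, unidirectional LSTM with two bias vectors per layer."""
    total = 0
    for layer in range(num_layers):
        # Layers after the first take the previous layer's hidden state as input.
        layer_input = input_dim if layer == 0 else hidden_dim
        # Four gates, each with input weights, recurrent weights and two biases.
        total += 4 * hidden_dim * (layer_input + hidden_dim + 2)
    return total


def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights plus bias of a fully connected layer."""
    return in_features * out_features + out_features


# Defaults of PricePredictionModel above: input_dim=2, hidden_dim=128, num_layers=2.
price_lstm = lstm_param_count(2, 128, 2)
# price_fc (128 -> 5) and extrema_fc (128 -> 10).
price_heads = linear_param_count(128, 5) + linear_param_count(128, 10)
```

The attention block is left out of this check; its count can be read from the script's per-layer output.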
crypto/gogo2/fix_indentation.py (Normal file), 42 lines added
@@ -0,0 +1,42 @@
def fix_indentation():
    with open('main.py', 'r') as f:
        lines = f.readlines()

    # Fix indentation for the problematic sections
    fixed_lines = []

    # Find the try block that starts at line 1693
    try_start_line = 1693
    try_block_found = False
    in_try_block = False

    for i, line in enumerate(lines):
        # Check if we're at the try statement
        if i+1 == try_start_line and 'try:' in line:
            try_block_found = True
            in_try_block = True
            fixed_lines.append(line)
        # Fix the indentation of the experiences line
        elif i+1 == 1695 and line.strip().startswith('experiences = self.memory.sample(BATCH_SIZE)'):
            # Add proper indentation (4 spaces)
            fixed_lines.append(' ' + line.lstrip())
        # Check if we're at the end of the try block without an except
        elif try_block_found and in_try_block and i+1 > try_start_line and line.strip() and not line.startswith(' '):
            # We've reached the end of the try block without an except, add one
            fixed_lines.append(' except Exception as e:\n')
            fixed_lines.append(' logger.error(f"Error during learning: {e}")\n')
            fixed_lines.append(' logger.error(f"Traceback: {traceback.format_exc()}")\n')
            fixed_lines.append(' return None\n\n')
            in_try_block = False
            fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.writelines(fixed_lines)

    print("Indentation fixed!")

if __name__ == "__main__":
    fix_indentation()
crypto/gogo2/fix_try_blocks.py (Normal file), 22 lines added
@@ -0,0 +1,22 @@
import re

def fix_try_blocks():
    with open('main.py', 'r') as f:
        content = f.read()

    # Find all try blocks without except or finally
    pattern = r'(\s+)try:\s*\n((?:\1\s+.*\n)+?)(?!\1\s*except|\1\s*finally)'

    # Replace with try-except blocks
    fixed_content = re.sub(pattern,
                           r'\1try:\n\2\1except Exception as e:\n\1 logger.error(f"Error: {e}")\n\1 logger.error(f"Traceback: {traceback.format_exc()}")\n\1 return None\n\n',
                           content)

    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.write(fixed_content)

    print("Try blocks fixed!")

if __name__ == "__main__":
    fix_try_blocks()
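
Both repair scripts above overwrite `main.py` in place without checking that the result still parses. A cheap safeguard, sketched here as an assumption rather than as part of the commit, is to run the rewritten source through `ast.parse` before writing it back. `check_python_source` is a hypothetical helper name.

```python
import ast
from typing import Optional


def check_python_source(source: str) -> Optional[str]:
    """Return None if the source parses, otherwise a short 'line N: message' string."""
    try:
        ast.parse(source)
    except SyntaxError as exc:  # IndentationError is a subclass of SyntaxError
        return f"line {exc.lineno}: {exc.msg}"
    return None


# A try block with no except/finally, the situation the scripts above try to repair.
broken = "def learn():\n    try:\n        x = 1\n"
fixed = broken + "    except Exception:\n        return None\n"

broken_report = check_python_source(broken)
fixed_report = check_python_source(fixed)
```

A script could then refuse to write the file whenever the report is not `None`, leaving the original `main.py` untouched.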
@@ -1151,10 +1151,10 @@ class TradingEnvironment:

 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10  # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5  # Increased positive reward for profit
     self.win_count += 1
 else:
-    reward = -1.0  # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3  # Stronger negative reward for loss
     self.loss_count += 1

 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@@ -1238,10 +1238,15 @@ class TradingEnvironment:

 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10  # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5  # Increased positive reward for profit
     self.win_count += 1
+
+    # Extra reward for closing at a predicted high
+    if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
+        reward += 1.0
+        logger.info("Closing long at predicted high - additional reward")
 else:
-    reward = -1.0  # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3  # Stronger negative reward for loss
     self.loss_count += 1

 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@@ -1296,15 +1301,15 @@ class TradingEnvironment:

 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10  # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5  # Increased positive reward for profit
     self.win_count += 1

     # Extra reward for closing at a predicted high
     if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
-        reward += 0.5
+        reward += 1.0
         logger.info("Closing long at predicted high - additional reward")
 else:
-    reward = -1.0  # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3  # Stronger negative reward for loss
     self.loss_count += 1

 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@@ -1354,15 +1359,15 @@ class TradingEnvironment:

 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10  # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5  # Increased positive reward for profit
     self.win_count += 1

     # Extra reward for closing at a predicted low
     if hasattr(self, 'has_predicted_low') and self.has_predicted_low:
-        reward += 0.5
+        reward += 1.0
         logger.info("Closing short at predicted low - additional reward")
 else:
-    reward = -1.0  # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3  # Stronger negative reward for loss
     self.loss_count += 1

 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@@ -1378,9 +1383,9 @@ class TradingEnvironment:
 # Add reward based on direct PnL change
 balance_change = self.balance - prev_balance
 if balance_change > 0:
-    reward += balance_change * 0.5  # Positive reward for making money
+    reward += balance_change * 1.0  # Increased positive reward for making money
 else:
-    reward += balance_change * 1.0  # Stronger negative reward for losing money
+    reward += balance_change * 2.0  # Stronger negative reward for losing money

 # Add reward for predicted price movement alignment
 if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
@ -1611,9 +1616,15 @@ class TradingEnvironment:
|
||||
|
||||
def initialize_price_predictor(self, device="cpu"):
|
||||
"""Initialize the price prediction model"""
|
||||
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
|
||||
self.price_predictor.to(device)
|
||||
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
|
||||
# Only create a new model if one doesn't already exist
|
||||
if not hasattr(self, 'price_predictor') or self.price_predictor is None:
|
||||
self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
|
||||
self.price_predictor.to(device)
|
||||
self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
|
||||
else:
|
||||
# If model exists, just ensure it's on the right device
|
||||
self.price_predictor.to(device)
|
||||
|
||||
self.predicted_prices = np.array([])
|
||||
self.predicted_extrema = np.array([])
|
||||
self.extrema_threshold = 0.7 # Threshold for extrema prediction confidence
|
||||
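Review note: the change to `initialize_price_predictor` makes the call idempotent, so an already trained predictor is no longer replaced on every call. A minimal sketch of the same guard, assuming a hypothetical `PredictorHolder` class and a generic `factory` callable in place of `PricePredictionModel` (neither name is from the commit):

```python
class PredictorHolder:
    """Illustrates the create-once guard; not the commit's TradingEnvironment."""

    def __init__(self, factory):
        self._factory = factory   # stands in for the model constructor
        self.build_count = 0      # counts how many times a model was built

    def initialize_price_predictor(self):
        # getattr(..., None) covers both "attribute missing" and "attribute is None",
        # which is what the hasattr(...) / "is None" pair in the diff checks.
        if getattr(self, "price_predictor", None) is None:
            self.price_predictor = self._factory()
            self.build_count += 1
        return self.price_predictor


if __name__ == "__main__":
    holder = PredictorHolder(factory=dict)
    first = holder.initialize_price_predictor()
    second = holder.initialize_price_predictor()
    print(first is second, holder.build_count)
```

In the committed version the lines after the guard still reset `predicted_prices` and `predicted_extrema` on every call, so only the model and its optimizer are preserved.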
@ -1766,16 +1777,16 @@ class TradingEnvironment:
|
||||
return fee
|
||||
|
||||
# Ensure GPU usage if available
|
||||
def get_device(device_preference='gpu'):
|
||||
def get_device(device_preference='auto'):
|
||||
"""Get the device to use (GPU or CPU) based on preference and availability"""
|
||||
if device_preference.lower() == 'gpu' and torch.cuda.is_available():
|
||||
if device_preference.lower() in ['gpu', 'auto'] and torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# Set default tensor type to float32 for CUDA
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
if device_preference.lower() == 'gpu':
|
||||
if device_preference.lower() in ['gpu', 'auto']:
|
||||
logger.info("GPU requested but not available, using CPU instead")
|
||||
else:
|
||||
logger.info("Using CPU as requested")
|
||||
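Review note: with the new `'auto'` default, the device choice reduces to a small truth table. The sketch below separates that decision from PyTorch so it can be tested without a GPU. The function name `resolve_device` and its `cuda_available` parameter are hypothetical; in the real code the flag would come from `torch.cuda.is_available()`.

```python
def resolve_device(preference: str, cuda_available: bool) -> tuple[str, str]:
    """Return (device_name, log_message) for a device preference.

    Hypothetical helper mirroring the branching in get_device from the diff.
    """
    wants_gpu = preference.lower() in ("gpu", "auto")
    if wants_gpu and cuda_available:
        return "cuda", "Using GPU"
    if wants_gpu:
        # Same message the diff logs for both 'gpu' and 'auto' when CUDA is missing.
        return "cpu", "GPU requested but not available, using CPU instead"
    return "cpu", "Using CPU as requested"


if __name__ == "__main__":
    for pref in ("auto", "gpu", "cpu"):
        for available in (True, False):
            print(pref, available, resolve_device(pref, available))
```

One consequence of the committed branching: a user who passes no `--device` flag on a CPU-only machine still sees the "GPU requested" message, because `'auto'` is treated as a GPU request.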
@ -1952,7 +1963,7 @@ class Agent:
|
||||
|
||||
# Use mixed precision for forward/backward passes
|
||||
if self.device.type == "cuda":
|
||||
with amp.autocast():
|
||||
with amp.autocast(device_type='cuda'):
|
||||
# Compute Q values
|
||||
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
||||
|
||||
@ -2943,10 +2954,12 @@ async def main():
|
||||
import traceback
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run the trading bot')
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'], help='Mode to run the bot in')
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live', 'continuous'], help='Mode to run the bot in')
|
||||
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train for')
|
||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||
parser.add_argument('--device', type=str, default='auto', choices=['cpu', 'gpu', 'auto'], help='Device to use for training')
|
||||
parser.add_argument('--refresh-data', '--refresh_data', dest='refresh_data', action='store_true', help='Refresh data at the start of each episode')
|
||||
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for data (e.g., 1s, 1m, 5m, 15m, 1h)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set device
|
||||
@ -2995,6 +3008,247 @@ async def main():
|
||||
results = evaluate_agent(agent, env, num_episodes=10)
|
||||
logger.info(f"Evaluation results: {results}")
|
||||
|
||||
elif args.mode == 'continuous':
|
||||
# Continuous training mode - train indefinitely with data refreshing
|
||||
logger.info("Starting continuous training mode...")
|
||||
|
||||
# Set refresh_data to True for continuous mode
|
||||
args.refresh_data = True
|
||||
|
||||
# Create directories for continuous models
|
||||
os.makedirs("models", exist_ok=True)
|
||||
|
||||
# Track best PnL for model selection
|
||||
best_pnl = float('-inf')
|
||||
best_pnl_model_path = "models/trading_agent_best_pnl.pt"
|
||||
|
||||
# Load the best PnL model if it exists
|
||||
if os.path.exists(best_pnl_model_path):
|
||||
logger.info(f"Loading best PnL model: {best_pnl_model_path}")
|
||||
agent.load(best_pnl_model_path)
|
||||
|
||||
# Try to load best PnL value from checkpoint file
|
||||
checkpoint_info_path = "checkpoints/best_metrics.json"
|
||||
if os.path.exists(checkpoint_info_path):
|
||||
with open(checkpoint_info_path, 'r') as f:
|
||||
best_metrics = json.load(f)
|
||||
best_pnl = best_metrics.get('best_pnl', best_pnl)
|
||||
logger.info(f"Loaded best PnL from checkpoint: ${best_pnl:.2f}")
|
||||
|
||||
# Initialize episode counter
|
||||
episode = 0
|
||||
|
||||
# Get timeframe from args
|
||||
timeframe = args.timeframe
|
||||
logger.info(f"Using timeframe: {timeframe}")
|
||||
|
||||
# Initialize TensorBoard writer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
tensorboard_dir = f"runs/continuous_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
writer = SummaryWriter(tensorboard_dir)
|
||||
logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
|
||||
|
||||
# Attach writer to agent
|
||||
agent.writer = writer
|
||||
|
||||
# Initialize stats dictionary for plotting
|
||||
stats = {
|
||||
'episode_rewards': [],
|
||||
'episode_profits': [],
|
||||
'win_rates': [],
|
||||
'trade_counts': [],
|
||||
'prediction_accuracies': []
|
||||
}
|
||||
|
||||
# Train continuously
|
||||
try:
|
||||
while True:
|
||||
logger.info(f"Continuous training - Episode {episode}")
|
||||
|
||||
# Refresh data from exchange with the specified timeframe
|
||||
logger.info(f"Refreshing market data with timeframe {timeframe}...")
|
||||
await env.fetch_new_data(exchange, "ETH/USDT", timeframe, 100)
|
||||
|
||||
# Reset environment
|
||||
state = env.reset()
|
||||
|
||||
# Initialize price predictor if not already initialized
|
||||
if not hasattr(env, 'price_predictor') or env.price_predictor is None:
|
||||
logger.info("Initializing price predictor...")
|
||||
env.initialize_price_predictor(device=agent.device)
|
||||
|
||||
# Initialize episode variables
|
||||
episode_reward = 0
|
||||
done = False
|
||||
|
||||
# Train price predictor
|
||||
prediction_loss, extrema_loss = env.train_price_predictor()
|
||||
|
||||
# Update price predictions
|
||||
env.update_price_predictions()
|
||||
|
||||
# Training loop for this episode
|
||||
while not done:
|
||||
# Select action
|
||||
action = agent.select_action(state)
|
||||
|
||||
# Take action
|
||||
next_state, reward, done = env.step(action)
|
||||
|
||||
# Store experience
|
||||
agent.memory.push(state, action, reward, next_state, done)
|
||||
|
||||
# Learn from experience
|
||||
loss = agent.learn()
|
||||
|
||||
# Update state and reward
|
||||
state = next_state
|
||||
episode_reward += reward
|
||||
|
||||
# Calculate win rate
|
||||
total_trades = env.win_count + env.loss_count
|
||||
win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
|
||||
|
||||
# Calculate prediction accuracy
|
||||
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
||||
# Compare predictions with actual prices
|
||||
actual_prices = env.features['price'][-len(env.predicted_prices):]
|
||||
prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices
|
||||
prediction_accuracy = 100 * (1 - np.mean(prediction_errors))
|
||||
else:
|
||||
prediction_accuracy = 0
|
||||
|
||||
# Update stats
|
||||
stats['episode_rewards'].append(episode_reward)
|
||||
stats['episode_profits'].append(env.episode_pnl)
|
||||
stats['win_rates'].append(win_rate)
|
||||
stats['trade_counts'].append(total_trades)
|
||||
stats['prediction_accuracies'].append(prediction_accuracy)
|
||||
|
||||
# Log to TensorBoard
|
||||
writer.add_scalar('Reward/continuous', episode_reward, episode)
|
||||
writer.add_scalar('Balance/continuous', env.balance, episode)
|
||||
writer.add_scalar('WinRate/continuous', win_rate, episode)
|
||||
writer.add_scalar('PnL/episode', env.episode_pnl, episode)
|
||||
writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
|
||||
writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
|
||||
writer.add_scalar('PredictionLoss', prediction_loss, episode)
|
||||
writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
|
||||
|
||||
# Log OHLCV data to TensorBoard every 5 episodes
|
||||
if episode % 5 == 0:
|
||||
# Create a DataFrame from the environment's data
|
||||
df_ohlcv = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||
df_ohlcv.set_index('timestamp', inplace=True)
|
||||
|
||||
# Extract buy/sell signals from trades
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
for trade in env.trades:
|
||||
if 'entry_time' in trade and 'entry' in trade:
|
||||
if trade['type'] == 'long':
|
||||
# Buy signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
buy_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Sell signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
sell_signals.append((exit_time, trade['exit']))
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
sell_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Buy to cover signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
buy_signals.append((exit_time, trade['exit']))
|
||||
|
||||
# Log to TensorBoard
|
||||
log_ohlcv_to_tensorboard(
|
||||
writer,
|
||||
df_ohlcv,
|
||||
buy_signals,
|
||||
sell_signals,
|
||||
episode,
|
||||
tag_prefix=f"continuous_episode_{episode}"
|
||||
)
|
||||
|
||||
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
|
||||
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
|
||||
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}")
|
||||
|
||||
# Create visualization every 10 episodes
|
||||
if episode % 10 == 0:
|
||||
# Create visualization
|
||||
os.makedirs("visualizations", exist_ok=True)
|
||||
visualize_training_results(env, agent, episode)
|
||||
|
||||
# Save model
|
||||
model_path = f"models/trading_agent_continuous_{episode}.pt"
|
||||
agent.save(model_path)
|
||||
logger.info(f"Saved continuous model: {model_path}")
|
||||
|
||||
# Plot training results
|
||||
plot_training_results(stats)
|
||||
|
||||
# Save best PnL model
|
||||
if env.episode_pnl > best_pnl:
|
||||
best_pnl = env.episode_pnl
|
||||
agent.save(best_pnl_model_path)
|
||||
logger.info(f"New best PnL model saved: ${env.episode_pnl:.2f}")
|
||||
|
||||
# Save best metrics to resume training if interrupted
|
||||
best_metrics = {
|
||||
'best_pnl': float(best_pnl),
|
||||
'last_episode': episode,
|
||||
'timestamp': datetime.datetime.now().isoformat()
|
||||
}
|
||||
os.makedirs("checkpoints", exist_ok=True)
|
||||
with open("checkpoints/best_metrics.json", 'w') as f:
|
||||
json.dump(best_metrics, f)
|
||||
|
||||
# Update target network
|
||||
agent.update_target_network()
|
||||
|
||||
# Increment episode counter
|
||||
episode += 1
|
||||
|
||||
# Sleep briefly to prevent overwhelming the system
|
||||
# Use shorter sleep for shorter timeframes
|
||||
if timeframe.endswith('s'):
|
||||
await asyncio.sleep(0.1) # Very short sleep for second-based timeframes
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Continuous training stopped by user")
|
||||
# Save final model
|
||||
agent.save("models/trading_agent_continuous_final.pt")
|
||||
# Close TensorBoard writer
|
||||
writer.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous training: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
# Save emergency model
|
||||
agent.save(f"models/trading_agent_continuous_emergency_{episode}.pt")
|
||||
# Close TensorBoard writer
|
||||
writer.close()
|
||||
|
||||
elif args.mode == 'evaluate':
|
||||
# Load the best model
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
|
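Review note: continuous mode persists `best_pnl` to `checkpoints/best_metrics.json` with a plain `open(..., 'w')`. If the process is killed mid-write, the file can be left truncated and the next start would fail in `json.load`. The sketch below shows one possible hardening, writing to a temporary file and renaming it into place. This is a swapped-in technique, not what the commit does, and the function names `save_best_metrics` and `load_best_pnl` are hypothetical.

```python
import json
import os
import tempfile
from datetime import datetime


def save_best_metrics(path: str, best_pnl: float, last_episode: int) -> None:
    """Write the metrics file atomically (temp file + os.replace)."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    payload = {
        "best_pnl": float(best_pnl),
        "last_episode": int(last_episode),
        "timestamp": datetime.now().isoformat(),
    }
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
        # os.replace swaps the file in one step, so readers never see a partial file.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_best_pnl(path: str, default: float = float("-inf")) -> float:
    """Read best_pnl back, falling back to the default if the file is absent."""
    if not os.path.exists(path):
        return default
    with open(path, "r") as f:
        return json.load(f).get("best_pnl", default)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        metrics_path = os.path.join(tmp, "checkpoints", "best_metrics.json")
        print(load_best_pnl(metrics_path))
        save_best_metrics(metrics_path, best_pnl=12.5, last_episode=3)
        print(load_best_pnl(metrics_path))
```

A related observation on the committed code: `last_episode` is written to the file, but on restart only `best_pnl` is read and `episode` starts again at 0, so files named `trading_agent_continuous_{episode}.pt` from a previous run may be overwritten.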
304
crypto/gogo2/mexc_trading.py
Normal file
@ -0,0 +1,304 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from mexc_api.spot import Spot
|
||||
from mexc_api.common.enums import Side, OrderType
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('mexc_trading')
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
|
||||
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
class MexcTradingClient:
|
||||
"""Client for executing trades on MEXC exchange using the official API"""
|
||||
|
||||
def __init__(self, api_key=None, api_secret=None, symbol="ETH/USDT", leverage=1):
|
||||
"""Initialize the MEXC trading client"""
|
||||
self.api_key = api_key or MEXC_API_KEY
|
||||
self.api_secret = api_secret or MEXC_SECRET_KEY
|
||||
|
||||
# Ensure API keys are not None
|
||||
if not self.api_key or not self.api_secret:
|
||||
logger.warning("API keys not provided. Using empty strings for public endpoints only.")
|
||||
self.api_key = ""
|
||||
self.api_secret = ""
|
||||
|
||||
self.symbol = symbol
|
||||
self.formatted_symbol = symbol.replace('/', '') # MEXC requires no slash
|
||||
self.leverage = leverage
|
||||
self.client = None
|
||||
self.position = 'flat' # 'flat', 'long', or 'short'
|
||||
self.position_size = 0
|
||||
self.entry_price = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
self.open_orders = []
|
||||
self.order_history = []
|
||||
self.trades = []
|
||||
self.win_count = 0
|
||||
self.loss_count = 0
|
||||
|
||||
# Initialize the MEXC API client
|
||||
self.initialize_client()
|
||||
|
||||
def initialize_client(self):
|
||||
"""Initialize the MEXC API client using the API"""
|
||||
try:
|
||||
self.client = Spot(self.api_key, self.api_secret)
|
||||
# Test connection
|
||||
server_time = self.client.market.server_time()
|
||||
logger.info(f"MEXC API client initialized successfully. Server time: {server_time}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MEXC API client: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def fetch_account_balance(self):
|
||||
"""Fetch account balance from MEXC API"""
|
||||
try:
|
||||
# Check if we have API keys for private endpoints
|
||||
if not self.api_key or self.api_key == "":
|
||||
logger.warning("No API keys provided. Cannot fetch account balance.")
|
||||
return 0
|
||||
|
||||
account_info = self.client.account.account_info()
|
||||
if 'balances' in account_info:
|
||||
# Find USDT balance
|
||||
for asset in account_info['balances']:
|
||||
if asset['asset'] == 'USDT':
|
||||
return float(asset['free'])
|
||||
|
||||
logger.warning("Could not find USDT balance")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching account balance: {e}")
|
||||
return 0
|
||||
|
||||
async def fetch_open_positions(self):
|
||||
"""Fetch open positions from MEXC API"""
|
||||
try:
|
||||
# Check if we have API keys for private endpoints
|
||||
if not self.api_key or self.api_key == "":
|
||||
logger.warning("No API keys provided. Cannot fetch open positions.")
|
||||
return []
|
||||
|
||||
# Fetch open orders
|
||||
open_orders = self.client.account.open_orders(self.formatted_symbol)
|
||||
return open_orders
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching open positions: {e}")
|
||||
return []
|
||||
|
||||
async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
|
||||
"""Open a new position using MEXC API"""
|
||||
try:
|
||||
# Check if we have API keys for private endpoints
|
||||
if not self.api_key or self.api_key == "":
|
||||
logger.warning("No API keys provided. Cannot open position.")
|
||||
return False
|
||||
|
||||
# Calculate quantity based on size and price
|
||||
quantity = size / entry_price
|
||||
# Round quantity to appropriate precision
|
||||
quantity = round(quantity, 4) # Adjust precision as needed for your asset
|
||||
|
||||
# Determine order side
|
||||
side = Side.BUY if position_type == 'long' else Side.SELL
|
||||
|
||||
logger.info(f"Opening {position_type} position: {quantity} {self.symbol} at market price")
|
||||
|
||||
# Place market order
|
||||
order_result = self.client.account.new_order(
|
||||
self.formatted_symbol,
|
||||
side,
|
||||
OrderType.MARKET,
|
||||
str(quantity)
|
||||
)
|
||||
|
||||
logger.info(f"Market order result: {order_result}")
|
||||
|
||||
# Check if order was filled
|
||||
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
|
||||
# Get actual entry price
|
||||
actual_entry_price = float(order_result.get('price', entry_price))
|
||||
|
||||
# Place stop loss order
|
||||
sl_side = Side.SELL if position_type == 'long' else Side.BUY
|
||||
sl_order = self.client.account.new_order(
|
||||
self.formatted_symbol,
|
||||
sl_side,
|
||||
OrderType.STOP_LOSS_LIMIT,
|
||||
str(quantity),
|
||||
price=str(stop_loss),
|
||||
stop_price=str(stop_loss),
|
||||
time_in_force="GTC"
|
||||
)
|
||||
|
||||
logger.info(f"Stop loss order placed: {sl_order}")
|
||||
|
||||
# Place take profit order
|
||||
tp_side = Side.SELL if position_type == 'long' else Side.BUY
|
||||
tp_order = self.client.account.new_order(
|
||||
self.formatted_symbol,
|
||||
tp_side,
|
||||
OrderType.TAKE_PROFIT_LIMIT,
|
||||
str(quantity),
|
||||
price=str(take_profit),
|
||||
stop_price=str(take_profit),
|
||||
time_in_force="GTC"
|
||||
)
|
||||
|
||||
logger.info(f"Take profit order placed: {tp_order}")
|
||||
|
||||
# Update local state
|
||||
self.position = position_type
|
||||
self.position_size = size
|
||||
self.entry_price = actual_entry_price
|
||||
self.stop_loss = stop_loss
|
||||
self.take_profit = take_profit
|
||||
|
||||
# Track orders
|
||||
self.open_orders.extend([sl_order, tp_order])
|
||||
self.order_history.append(order_result)
|
||||
|
||||
logger.info(f"Successfully opened {position_type} position at {actual_entry_price}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to open position: {order_result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error opening position: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def close_position(self, reason="manual_close"):
|
||||
"""Close an existing position"""
|
||||
if self.position == 'flat':
|
||||
logger.info("No position to close")
|
||||
return True
|
||||
|
||||
try:
|
||||
# Check if we have API keys for private endpoints
|
||||
if not self.api_key or self.api_key == "":
|
||||
logger.warning("No API keys provided. Cannot close position.")
|
||||
return False
|
||||
|
||||
# First, cancel any existing stop loss/take profit orders
|
||||
try:
|
||||
self.client.account.cancel_open_orders(self.formatted_symbol)
|
||||
logger.info("Canceled all open orders")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error canceling open orders: {e}")
|
||||
|
||||
# Determine order side (opposite of position)
|
||||
side = Side.SELL if self.position == 'long' else Side.BUY
|
||||
|
||||
# Calculate quantity
|
||||
quantity = self.position_size / self.entry_price
|
||||
# Round quantity to appropriate precision
|
||||
quantity = round(quantity, 4) # Adjust precision as needed
|
||||
|
||||
logger.info(f"Closing {self.position} position: {quantity} {self.symbol} at market price")
|
||||
|
||||
# Execute market order to close position
|
||||
order_result = self.client.account.new_order(
|
||||
self.formatted_symbol,
|
||||
side,
|
||||
OrderType.MARKET,
|
||||
str(quantity)
|
||||
)
|
||||
|
||||
logger.info(f"Close order result: {order_result}")
|
||||
|
||||
# Check if order was filled
|
||||
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
|
||||
# Get actual exit price
|
||||
exit_price = float(order_result.get('price', 0))
|
||||
|
||||
# Calculate PnL
|
||||
if self.position == 'long':
|
||||
pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
|
||||
else: # short
|
||||
pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100
|
||||
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Record trade
|
||||
self.trades.append({
|
||||
'type': self.position,
|
||||
'entry': self.entry_price,
|
||||
'exit': exit_price,
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'reason': reason,
|
||||
'order_id': order_result.get('orderId')
|
||||
})
|
||||
|
||||
# Update win/loss count
|
||||
if pnl_dollar > 0:
|
||||
self.win_count += 1
|
||||
else:
|
||||
self.loss_count += 1
|
||||
|
||||
# Track order history
|
||||
self.order_history.append(order_result)
|
||||
|
||||
logger.info(f"Closed {self.position} position at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
|
||||
|
||||
# Reset position
|
||||
self.position = 'flat'
|
||||
self.entry_price = 0
|
||||
self.position_size = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to close position: {order_result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing position: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def check_order_status(self, order_id):
|
||||
"""Check the status of a specific order"""
|
||||
try:
|
||||
# Check if we have API keys for private endpoints
|
||||
if not self.api_key or self.api_key == "":
|
||||
logger.warning("No API keys provided. Cannot check order status.")
|
||||
return None
|
||||
|
||||
order_status = self.client.account.query_order(
|
||||
self.formatted_symbol,
|
||||
order_id=order_id
|
||||
)
|
||||
return order_status
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking order status: {e}")
|
||||
return None
|
||||
|
||||
async def get_market_price(self):
|
||||
"""Get current market price for the symbol"""
|
||||
try:
|
||||
ticker = self.client.market.ticker_price(self.formatted_symbol)
|
||||
if isinstance(ticker, list) and len(ticker) > 0 and 'price' in ticker[0]:
|
||||
return float(ticker[0]['price'])
|
||||
elif isinstance(ticker, dict) and 'price' in ticker:
|
||||
return float(ticker['price'])
|
||||
else:
|
||||
logger.error(f"Unexpected ticker format: {ticker}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market price: {e}")
|
||||
return None
|
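Review note: `close_position` derives the exit price from `order_result.get('price', 0)`. If the response carries no usable price, the exit price becomes 0 and a long position would be recorded as a -100% loss. The sketch below isolates the PnL arithmetic from the diff and rejects non-positive prices instead. The function name `position_pnl` and the validation step are assumptions added for illustration; the formulas themselves are the ones in the commit.

```python
def position_pnl(position: str, entry_price: float, exit_price: float,
                 position_size: float) -> tuple[float, float]:
    """Return (pnl_percent, pnl_dollar) using the formulas from close_position.

    position_size is the notional size in quote currency (USDT in the diff).
    """
    if position not in ("long", "short"):
        raise ValueError(f"unknown position type: {position!r}")
    if entry_price <= 0 or exit_price <= 0:
        # Assumption: treat a missing/zero fill price as an error rather than
        # silently booking a total loss.
        raise ValueError("entry and exit prices must be positive")

    if position == "long":
        pnl_percent = (exit_price - entry_price) / entry_price * 100
    else:  # short
        pnl_percent = (entry_price - exit_price) / entry_price * 100

    pnl_dollar = pnl_percent / 100 * position_size
    return pnl_percent, pnl_dollar


if __name__ == "__main__":
    print(position_pnl("long", 100.0, 110.0, 1000.0))
    print(position_pnl("short", 100.0, 110.0, 1000.0))
```

How the real fill price should be recovered when the order response omits it (for example by querying the order afterwards) depends on the exchange API and is left open here.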
456
crypto/gogo2/mexcapi-readme.md
Normal file
@ -0,0 +1,456 @@
|
||||
# mexc-api-sdk
|
||||
|
||||
The official MEXC market and trade API SDK: an easy way to connect and send requests to the MEXC open API.
|
||||
|
||||
## Prerequisites
|
||||
- To use our SDK you have to install nodejs LTS (https://aws.github.io/jsii/user-guides/lib-user/)
|
||||
|
||||
## Installation
|
||||
1.
|
||||
```
|
||||
git clone https://github.com/mxcdevelop/mexc-api-sdk.git
|
||||
```
|
||||
2. cd dist/{language} and unzip the file
|
||||
3. We offer five languages: dotnet, go, java, js, python
|
||||
|
||||
## Table of APIS
|
||||
- [Init](#init)
|
||||
- [Market](#market)
|
||||
- [Ping](#ping)
|
||||
- [Check Server Time](#check-server-time)
|
||||
- [Exchange Information](#exchange-information)
|
||||
- [Recent Trades List](#recent-trades-list)
|
||||
- [Order Book](#order-book)
|
||||
- [Old Trade Lookup](#old-trade-lookup)
|
||||
- [Aggregate Trades List](#aggregate-trades-list)
|
||||
- [kline Data](#kline-data)
|
||||
- [Current Average Price](#current-average-price)
|
||||
- [24hr Ticker Price Change Statistics](#24hr-ticker-price-change-statistics)
|
||||
- [Symbol Price Ticker](#symbol-price-ticker)
|
||||
- [Symbol Order Book Ticker](#symbol-order-book-ticker)
|
||||
- [Trade](#trade)
|
||||
- [Test New Order](#test-new-order)
|
||||
- [New Order](#new-order)
|
||||
- [cancel-order](#cancel-order)
|
||||
- [Cancel all Open Orders on a Symbol](#cancel-all-open-orders-on-a-symbol)
|
||||
- [Query Order](#query-order)
|
||||
- [Current Open Orders](#current-open-orders)
|
||||
- [All Orders](#all-orders)
|
||||
- [Account Information](#account-information)
|
||||
- [Account Trade List](#account-trade-list)
|
||||
## Init
|
||||
```javascript
|
||||
//Javascript
|
||||
import * as Mexc from 'mexc-sdk';
|
||||
const apiKey = 'apiKey'
|
||||
const apiSecret = 'apiSecret'
|
||||
const client = new Mexc.Spot(apiKey, apiSecret);
|
||||
```
|
||||
```go
|
||||
// Go
|
||||
package main
|
||||
import (
|
||||
"fmt"
|
||||
"mexc-sdk/mexcsdk"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiKey := "apiKey"
|
||||
apiSecret := "apiSecret"
|
||||
spot := mexcsdk.NewSpot(apiKey, apiSecret)
|
||||
}
|
||||
```
|
||||
```python
|
||||
# python
|
||||
from mexc_sdk import Spot
|
||||
spot = Spot(api_key='apiKey', api_secret='apiSecret')
|
||||
```
|
||||
```java
|
||||
// java
|
||||
import Mexc.Sdk.*;
|
||||
class MyClass {
|
||||
public static void main(String[] args) {
|
||||
String apiKey= "apiKey";
|
||||
String apiSecret= "apiSecret";
|
||||
Spot mySpot = new Spot(apiKey, apiSecret);
|
||||
}
|
||||
}
|
||||
```
|
||||
```C#
|
||||
// dotnet
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Mxc.Sdk;
|
||||
|
||||
namespace dotnet
|
||||
{
|
||||
class Program
|
||||
{
|
||||
static void Main(string[] args)
|
||||
{
|
||||
string apiKey = "apiKey";
|
||||
string apiSecret= "apiSecret";
|
||||
var spot = new Spot(apiKey, apiSecret);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Market
|
||||
### Ping
|
||||
```javascript
|
||||
client.ping()
|
||||
```
|
||||
### Check Server Time
|
||||
```javascript
|
||||
client.time()
|
||||
```
|
||||
### Exchange Information
|
||||
```javascript
|
||||
client.exchangeInfo(options: any)
|
||||
options:{symbol, symbols}
|
||||
/**
|
||||
* choose one parameter
|
||||
*
|
||||
* symbol :
|
||||
* example "BNBBTC";
|
||||
*
|
||||
* symbols :
|
||||
* array of symbol
|
||||
* example ["BTCUSDT","BNBBTC"];
|
||||
*
|
||||
*/
|
||||
```
|
||||
|
||||
### Recent Trades List
|
||||
```javascript
|
||||
client.trades(symbol: string, options: any = { limit: 500 })
|
||||
options:{limit}
|
||||
/**
|
||||
*
|
||||
* limit :
|
||||
* Number of returned data
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
*/
|
||||
```
|
||||
### Order Book
|
||||
```javascript
|
||||
client.depth(symbol: string, options: any = { limit: 100 })
|
||||
options:{limit}
|
||||
/**
|
||||
* limit :
|
||||
* Number of returned data
|
||||
* Default 100;
|
||||
* max 5000;
|
||||
* Valid:[5, 10, 20, 50, 100, 500, 1000, 5000]
|
||||
*
|
||||
*/
|
||||
```
|
||||
|
||||
### Old Trade Lookup
|
||||
```javascript
|
||||
client.historicalTrades(symbol: string, options: any = { limit: 500 })
|
||||
options:{limit, fromId}
|
||||
/**
|
||||
*
|
||||
* limit :
|
||||
* Number of returned data
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
* fromId:
|
||||
* Trade id to fetch from. Default gets most recent trades
|
||||
*
|
||||
*/
|
||||
|
||||
```
|
||||
|
||||
### Aggregate Trades List
|
||||
```javascript
|
||||
client.aggTrades(symbol: string, options: any = { limit: 500 })
|
||||
options:{fromId, startTime, endTime, limit}
|
||||
/**
|
||||
*
|
||||
* fromId :
|
||||
* id to get aggregate trades from INCLUSIVE
|
||||
*
|
||||
* startTime:
|
||||
* start at
|
||||
*
|
||||
* endTime:
|
||||
* end at
|
||||
*
|
||||
* limit :
|
||||
* Number of returned data
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
*/
|
||||
```
|
||||
### kline Data
|
||||
```javascript
|
||||
client.klines(symbol: string, interval: string, options: any = { limit: 500 })
|
||||
options:{ startTime, endTime, limit}
|
||||
/**
|
||||
*
|
||||
* interval :
|
||||
* m :minute;
|
||||
* h :Hour;
|
||||
* d :day;
|
||||
* w :week;
|
||||
* M :month
|
||||
* example : "1m"
|
||||
*
|
||||
* startTime :
|
||||
* start at
|
||||
*
|
||||
* endTime :
|
||||
* end at
|
||||
*
|
||||
* limit :
|
||||
* Number of returned data
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
*/
|
||||
```
|
||||
|
||||
### Current Average Price
|
||||
```javascript
|
||||
client.avgPrice(symbol: string)
|
||||
```
|
||||
### 24hr Ticker Price Change Statistics
|
||||
```javascript
|
||||
client.ticker24hr(symbol?: string)
|
||||
```
|
||||
### Symbol Price Ticker
|
||||
```javascript
|
||||
client.tickerPrice(symbol?: string)
|
||||
```
|
||||
|
||||
### Symbol Order Book Ticker
|
||||
```javascript
|
||||
client.bookTicker(symbol?: string)
|
||||
```
|
||||
## Trade
|
||||
### Test New Order
|
||||
```javascript
|
||||
client.newOrderTest(symbol: string, side: string, orderType: string, options: any = {})
|
||||
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
|
||||
/**
|
||||
*
|
||||
* side:
|
||||
* Order side
|
||||
* ENUM:
|
||||
* BUY
|
||||
* SELL
|
||||
*
|
||||
* orderType:
|
||||
* Order type
|
||||
* ENUM:
|
||||
* LIMIT
|
||||
* MARKET
|
||||
* STOP_LOSS
|
||||
* STOP_LOSS_LIMIT
|
||||
* TAKE_PROFIT
|
||||
* TAKE_PROFIT_LIMIT
|
||||
* LIMIT_MAKER
|
||||
*
|
||||
* timeInForce :
|
||||
* How long an order will be active before expiration.
|
||||
* GTC: Active unless the order is canceled
|
||||
* IOC: Order will try to fill the order as much as it can before the order expires
|
||||
* FOK: Expires if the full order cannot be filled upon execution.
|
||||
*
|
||||
* quantity :
|
||||
* target quantity
|
||||
*
|
||||
* quoteOrderQty :
|
||||
* Specify the total spent or received
|
||||
*
|
||||
* price :
|
||||
* target price
|
||||
*
|
||||
* newClientOrderId :
|
||||
* A unique id among open orders. Automatically generated if not sent
|
||||
*
|
||||
* stopPrice :
|
||||
* Used with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
|
||||
*
|
||||
* icebergQty :
|
||||
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
|
||||
*
|
||||
* newOrderRespType :
|
||||
* Set the response JSON. ACK, RESULT, or FULL;
|
||||
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
|
||||
*
|
||||
* recvWindow :
|
||||
* Time window in milliseconds during which the request is accepted
|
||||
* The value cannot be greater than 60000
|
||||
* defaults: 5000
|
||||
*
|
||||
*/
|
||||
|
||||
```
|
||||
|
||||
### New Order
|
||||
```javascript
|
||||
client.newOrder(symbol: string, side: string, orderType: string, options: any = {})
|
||||
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
|
||||
/**
|
||||
*
|
||||
* side:
|
||||
* Order side
|
||||
* ENUM:
|
||||
* BUY
|
||||
* SELL
|
||||
*
|
||||
* orderType:
|
||||
* Order type
|
||||
* ENUM:
|
||||
* LIMIT
|
||||
* MARKET
|
||||
* STOP_LOSS
|
||||
* STOP_LOSS_LIMIT
|
||||
* TAKE_PROFIT
|
||||
* TAKE_PROFIT_LIMIT
|
||||
* LIMIT_MAKER
|
||||
*
|
||||
* timeInForce :
|
||||
* How long an order will be active before expiration.
|
||||
* GTC: Active unless the order is canceled
|
||||
* IOC: Order will try to fill the order as much as it can before the order expires
|
||||
* FOK: Expires if the full order cannot be filled upon execution.
|
||||
*
|
||||
* quantity :
|
||||
* target quantity
|
||||
*
|
||||
* quoteOrderQty :
|
||||
* Specify the total spent or received
|
||||
*
|
||||
* price :
|
||||
* target price
|
||||
*
|
||||
* newClientOrderId :
|
||||
* A unique id among open orders. Automatically generated if not sent
|
||||
*
|
||||
* stopPrice :
|
||||
* Used with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
|
||||
*
|
||||
* icebergQty :
|
||||
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
|
||||
*
|
||||
* newOrderRespType :
|
||||
* Set the response JSON. ACK, RESULT, or FULL;
|
||||
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
|
||||
*
|
||||
* recvWindow :
|
||||
* Time window in milliseconds during which the request is accepted
|
||||
* The value cannot be greater than 60000
|
||||
* defaults: 5000
|
||||
*
|
||||
*/
|
||||
|
||||
```
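
The enums and constraints documented above lend themselves to a local check before an order is submitted. The following Python sketch encodes only the rules stated in this section (allowed sides, order types, `timeInForce` values, response types, which order types take `stopPrice` or `icebergQty`, and the `recvWindow` ceiling). The function `validate_order` is hypothetical; the exchange applies further rules that are not modeled here.

```python
from typing import Any, Dict, List

SIDES = {"BUY", "SELL"}
ORDER_TYPES = {
    "LIMIT", "MARKET", "STOP_LOSS", "STOP_LOSS_LIMIT",
    "TAKE_PROFIT", "TAKE_PROFIT_LIMIT", "LIMIT_MAKER",
}
TIME_IN_FORCE = {"GTC", "IOC", "FOK"}
RESPONSE_TYPES = {"ACK", "RESULT", "FULL"}
STOP_PRICE_TYPES = {"STOP_LOSS", "STOP_LOSS_LIMIT", "TAKE_PROFIT", "TAKE_PROFIT_LIMIT"}
ICEBERG_TYPES = {"LIMIT", "STOP_LOSS_LIMIT", "TAKE_PROFIT_LIMIT"}
MAX_RECV_WINDOW = 60000  # documented upper bound


def validate_order(side: str, order_type: str, options: Dict[str, Any]) -> List[str]:
    """Hypothetical pre-flight check; returns a list of problems (empty if none)."""
    errors: List[str] = []
    if side not in SIDES:
        errors.append(f"unknown side: {side}")
    if order_type not in ORDER_TYPES:
        errors.append(f"unknown orderType: {order_type}")
    if "timeInForce" in options and options["timeInForce"] not in TIME_IN_FORCE:
        errors.append(f"unknown timeInForce: {options['timeInForce']}")
    if "newOrderRespType" in options and options["newOrderRespType"] not in RESPONSE_TYPES:
        errors.append(f"unknown newOrderRespType: {options['newOrderRespType']}")
    if "stopPrice" in options and order_type not in STOP_PRICE_TYPES:
        errors.append(f"stopPrice is not used with {order_type} orders")
    if "icebergQty" in options and order_type not in ICEBERG_TYPES:
        errors.append(f"icebergQty is not used with {order_type} orders")
    if options.get("recvWindow", 0) > MAX_RECV_WINDOW:
        errors.append(f"recvWindow cannot be greater than {MAX_RECV_WINDOW}")
    return errors
```

Returning a list rather than raising on the first problem lets a caller log every issue at once, which is convenient when orders are generated by an automated agent.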
|
||||
|
||||
### Cancel Order
|
||||
```javascript
|
||||
client.cancelOrder(symbol: string, options:any = {})
|
||||
options:{ orderId, origClientOrderId, newClientOrderId}
|
||||
/**
|
||||
*
|
||||
* Either orderId or origClientOrderId must be sent
|
||||
*
|
||||
* orderId:
|
||||
* target orderId
|
||||
*
|
||||
* origClientOrderId:
|
||||
* target origClientOrderId
|
||||
*
|
||||
* newClientOrderId:
|
||||
* Used to uniquely identify this cancel. Automatically generated by default.
|
||||
*
|
||||
*/
|
||||
|
||||
```
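
The rule that either `orderId` or `origClientOrderId` must be sent can be enforced on the caller's side. This Python sketch shows one way to do that; `build_order_lookup` is a hypothetical helper that only mirrors the option names documented above.

```python
from typing import Any, Dict, Optional


def build_order_lookup(
    order_id: Optional[str] = None,
    orig_client_order_id: Optional[str] = None,
    new_client_order_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build cancel/query options with one identifier required."""
    if order_id is None and orig_client_order_id is None:
        raise ValueError("Either orderId or origClientOrderId must be sent")
    options = {
        "orderId": order_id,
        "origClientOrderId": orig_client_order_id,
        "newClientOrderId": new_client_order_id,
    }
    # Omit unset keys so optional values fall back to their defaults.
    return {key: value for key, value in options.items() if value is not None}
```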
|
||||
|
||||
### Cancel all Open Orders on a Symbol
|
||||
```javascript
|
||||
client.cancelOpenOrders(symbol: string)
|
||||
```
|
||||
### Query Order
|
||||
```javascript
|
||||
client.queryOrder(symbol: string, options:any = {})
|
||||
options:{ orderId, origClientOrderId}
|
||||
/**
|
||||
*
|
||||
* Either orderId or origClientOrderId must be sent
|
||||
*
|
||||
* orderId:
|
||||
* target orderId
|
||||
*
|
||||
* origClientOrderId:
|
||||
* target origClientOrderId
|
||||
*
|
||||
*/
|
||||
```
|
||||
### Current Open Orders
|
||||
```javascript
|
||||
client.openOrders(symbol: string)
|
||||
```
|
||||
### All Orders
|
||||
```javascript
|
||||
client.allOrders(symbol: string, options: any = { limit: 500 })
|
||||
options:{ orderId, startTime, endTime, limit}
|
||||
|
||||
/**
|
||||
*
|
||||
* orderId:
|
||||
* target orderId
|
||||
*
|
||||
* startTime:
|
||||
* start at
|
||||
*
|
||||
* endTime:
|
||||
* end at
|
||||
*
|
||||
* limit :
|
||||
* Number of records to return
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
*/
|
||||
```
|
||||
### Account Information
|
||||
```javascript
|
||||
client.accountInfo()
|
||||
```
|
||||
### Account Trade List
|
||||
```javascript
|
||||
client.accountTradeList(symbol: string, options:any = { limit: 500 })
|
||||
options:{ orderId, startTime, endTime, fromId, limit}
|
||||
|
||||
/**
|
||||
*
|
||||
* orderId:
|
||||
* target orderId
|
||||
*
|
||||
* startTime:
|
||||
* start at
|
||||
*
|
||||
* endTime:
|
||||
* end at
|
||||
*
|
||||
* fromId:
|
||||
* TradeId to fetch from. Default gets most recent trades
|
||||
*
|
||||
* limit :
|
||||
* Number of records to return
|
||||
* Default 500;
|
||||
* max 1000;
|
||||
*
|
||||
*/
|
||||
```
|
@ -1,188 +1,91 @@
|
||||
# Crypto Trading Bot with Reinforcement Learning
|
||||
# Crypto Trading Bot with MEXC API Integration
|
||||
|
||||
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition.
|
||||
This is an AI-powered cryptocurrency trading bot that can run in both simulation (demo) mode and live trading mode using the MEXC exchange API.
|
||||
|
||||
## Features
|
||||
|
||||
- Deep Q-Learning with experience replay
|
||||
- LSTM layers for sequential data processing
|
||||
- Multi-head attention mechanism
|
||||
- Dueling DQN architecture
|
||||
- Real-time trading capabilities
|
||||
- TensorBoard integration for monitoring
|
||||
- Comprehensive technical indicators
|
||||
- Demo and live trading modes
|
||||
- Automatic model checkpointing
|
||||
- Deep Reinforcement Learning agent for trading decisions
|
||||
- Technical indicators and price prediction
|
||||
- Live trading integration with MEXC exchange via mexc-api
|
||||
- Demo mode for testing without real trades
|
||||
- Real-time data streaming via websockets
|
||||
- Performance tracking and visualization
|
||||
|
||||
## Prerequisites
|
||||
## Setup
|
||||
|
||||
- Python 3.8+
|
||||
- MEXC Exchange API credentials
|
||||
- GPU recommended but not required
|
||||
1. Clone the repository
|
||||
2. Install dependencies:
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
3. Create a `.env` file in the root directory with your MEXC API keys:
|
||||
```
|
||||
MEXC_API_KEY=your_api_key_here
|
||||
MEXC_SECRET_KEY=your_secret_key_here
|
||||
```
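
For reference, the two variables above might be read along the following lines. This is a minimal sketch that takes any mapping (for example `os.environ` after `load_dotenv()` has run); the function name `read_mexc_credentials` is hypothetical, and treating blank values as missing is an assumption.

```python
from typing import Mapping, Optional, Tuple


def read_mexc_credentials(env: Mapping[str, str]) -> Optional[Tuple[str, str]]:
    """Hypothetical helper: return (api_key, secret_key), or None if either is unset."""
    api_key = env.get("MEXC_API_KEY", "").strip()
    secret_key = env.get("MEXC_SECRET_KEY", "").strip()
    if api_key and secret_key:
        return api_key, secret_key
    # Assumption: blank or missing keys mean only public endpoints are usable.
    return None
```

Returning `None` instead of raising lets the caller fall back to demo mode or public endpoints when no keys are configured.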
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yourusername/crypto-trading-bot.git
|
||||
cd crypto-trading-bot
|
||||
```
|
||||
2. Create a virtual environment:
|
||||
|
||||
```bash
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
```
|
||||
3. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
4. Create a `.env` file in the project root with your MEXC API credentials:
|
||||
|
||||
```bash
|
||||
MEXC_API_KEY=your_api_key
|
||||
MEXC_API_SECRET=your_api_secret
|
||||
|
||||
|
||||
cuda support
|
||||
|
||||
```bash
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
## Usage
|
||||
|
||||
The bot can be run in three modes:
|
||||
The bot can be run in three different modes:
|
||||
|
||||
### Training Mode
|
||||
|
||||
```bash
|
||||
python main.py --mode train --episodes 1000
|
||||
Train the agent on historical data:
|
||||
|
||||
```
|
||||
python main.py --mode train --episodes 100
|
||||
```
|
||||
|
||||
### Evaluation Mode
|
||||
|
||||
```bash
|
||||
python main.py --mode eval --episodes 10
|
||||
Evaluate the trained agent on historical data:
|
||||
|
||||
```
|
||||
python main.py --mode evaluate
|
||||
```
|
||||
|
||||
### Live Trading Mode
|
||||
|
||||
```bash
|
||||
# Demo mode (simulated trading with real market data)
|
||||
python main.py --mode live --demo
|
||||
Run the bot in live trading mode:
|
||||
|
||||
# Real trading (actual trades on MEXC)
|
||||
```
|
||||
python main.py --mode live
|
||||
```
|
||||
|
||||
Demo mode simulates trading using real-time market data but does not execute actual trades. It still:
|
||||
- Logs all trading decisions and performance metrics
|
||||
- Updates the model based on market data (if in training mode)
|
||||
- Displays real-time analytics and position information
|
||||
- Calculates theoretical profits/losses
|
||||
- Saves performance data to TensorBoard
|
||||
To run in demo mode (no real trades):
|
||||
|
||||
This makes it perfect for testing strategies without financial risk.
|
||||
```
|
||||
python main.py --mode live --demo
|
||||
```
|
||||
|
||||
## Live Trading Implementation
|
||||
|
||||
The bot uses the mexc-api package to execute trades on the MEXC exchange. The implementation includes:
|
||||
|
||||
- Market order execution for opening and closing positions
|
||||
- Stop loss and take profit orders
|
||||
- Real-time balance updates
|
||||
- Trade history tracking
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters can be adjusted in `main.py`:
|
||||
You can adjust the following parameters in `main.py`:
|
||||
|
||||
- `INITIAL_BALANCE`: Starting balance for training/demo
|
||||
- `MAX_LEVERAGE`: Maximum leverage for trades
|
||||
- `STOP_LOSS_PERCENT`: Stop loss percentage
|
||||
- `TAKE_PROFIT_PERCENT`: Take profit percentage
|
||||
- `BATCH_SIZE`: Training batch size
|
||||
- `LEARNING_RATE`: Model learning rate
|
||||
- `STATE_SIZE`: Size of the state representation
|
||||
- `INITIAL_BALANCE`: Starting balance for simulation
|
||||
- `MAX_LEVERAGE`: Leverage to use for trading
|
||||
- `STOP_LOSS_PERCENT`: Default stop loss percentage
|
||||
- `TAKE_PROFIT_PERCENT`: Default take profit percentage
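
To make the two percentage parameters concrete, the sketch below derives exit price levels from an entry price. The numeric values are placeholders rather than the bot's actual defaults, `exit_levels` is a hypothetical helper, and the percentages are assumed to apply to the price move (not to leveraged profit and loss).

```python
from typing import Tuple

STOP_LOSS_PERCENT = 0.5     # placeholder value, in percent
TAKE_PROFIT_PERCENT = 1.5   # placeholder value, in percent


def exit_levels(entry_price: float, side: str) -> Tuple[float, float]:
    """Hypothetical helper: return (stop_loss_price, take_profit_price)."""
    stop_fraction = STOP_LOSS_PERCENT / 100
    profit_fraction = TAKE_PROFIT_PERCENT / 100
    if side == "long":
        # A long position loses when price falls and gains when it rises.
        return entry_price * (1 - stop_fraction), entry_price * (1 + profit_fraction)
    if side == "short":
        # A short position is the mirror image.
        return entry_price * (1 + stop_fraction), entry_price * (1 - profit_fraction)
    raise ValueError(f"side must be 'long' or 'short', got {side!r}")
```

With leverage, the same price move produces a proportionally larger change in account equity, so the effective risk per trade also depends on `MAX_LEVERAGE`.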
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The DQN model includes:
|
||||
- Input layer with technical indicators
|
||||
- LSTM layers for temporal pattern recognition
|
||||
- Multi-head attention mechanism
|
||||
- Dueling architecture for better Q-value estimation
|
||||
- Batch normalization for stable training
|
||||
|
||||
## Monitoring
|
||||
|
||||
Training progress can be monitored using TensorBoard:
|
||||
|
||||
|
||||
Training progress is logged to TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir=logs
|
||||
```
|
||||
|
||||
This will show:
|
||||
- Training rewards
|
||||
- Account balance
|
||||
- Win rate
|
||||
- Loss metrics
|
||||
|
||||
## Trading Strategy
|
||||
|
||||
The bot makes decisions based on:
|
||||
- Price action
|
||||
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
||||
- Historical patterns through LSTM
|
||||
- Risk management with stop-loss and take-profit
|
||||
|
||||
## Safety Features
|
||||
|
||||
- Demo mode for safe testing
|
||||
- Automatic stop-loss
|
||||
- Position size limits
|
||||
- Error handling for API calls
|
||||
- Logging of all actions
|
||||
|
||||
## Directory Structure
|
||||
├── main.py # Main bot implementation
|
||||
├── requirements.txt # Project dependencies
|
||||
├── .env # API credentials
|
||||
├── models/ # Saved model checkpoints
|
||||
├── runs/ # TensorBoard logs
|
||||
└── trading_bot.log # Activity logs
|
||||
## Architecture
|
||||
|
||||
- `main.py`: Main entry point and trading logic
|
||||
- `mexc_trading.py`: MEXC API integration for live trading using mexc-api
|
||||
- `models/`: Directory for saved model weights
|
||||
|
||||
## Warning
|
||||
|
||||
Cryptocurrency trading carries significant risks. This bot is for educational purposes and should not be used with real money without thorough testing and understanding of the risks involved.
|
||||
Trading cryptocurrencies involves significant risk. This bot is provided for educational purposes only. Use at your own risk.
|
||||
|
||||
## License
|
||||
|
||||
[MIT License](LICENSE)
|
||||
|
||||
The main changes I made:
|
||||
Fixed code block formatting by adding proper language identifiers
|
||||
Added missing closing code blocks
|
||||
Properly formatted directory structure
|
||||
Added complete sections that were cut off in the original
|
||||
Ensured consistent formatting throughout the document
|
||||
Added proper bash syntax highlighting for command examples
|
||||
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Edits/improvements
|
||||
|
||||
Fixes the shape mismatch by ensuring the state vector is exactly STATE_SIZE elements
|
||||
Adds robust error handling in the model's forward pass to handle mismatched inputs
|
||||
Adds a transformer encoder for more sophisticated pattern recognition
|
||||
Provides an expand_model method to increase model capacity while preserving learned weights
|
||||
Adds detailed logging about model size and shape mismatches
|
||||
The model now has:
|
||||
Configurable hidden layer sizes
|
||||
Transformer layers for complex pattern recognition
|
||||
LSTM layers for temporal patterns
|
||||
Attention mechanisms for focusing on important features
|
||||
Dueling architecture for better Q-value estimation
|
||||
With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need to implement a more distributed architecture with multiple GPUs, which would require significant changes to the training loop.
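
The scaling claim above can be sanity-checked with a quick count. The sketch below covers only the stacked LSTM weights, using the standard PyTorch layout (four gates, with input weights, recurrent weights, and two bias vectors per layer). The input size of 64 is an assumed placeholder, and the transformer, attention, and dueling heads are not included, so the result is a lower bound rather than the model's actual size.

```python
def lstm_param_count(input_size: int, hidden_size: int, num_layers: int = 1) -> int:
    """Count trainable parameters of a unidirectional stacked LSTM (PyTorch layout)."""
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # Four gates, each with input weights, recurrent weights, and two biases.
        total += 4 * (hidden_size * (layer_input + hidden_size) + 2 * hidden_size)
        # Every layer after the first consumes the previous layer's hidden state.
        layer_input = hidden_size
    return total


if __name__ == "__main__":
    for hidden in (256, 512, 1024):
        print(hidden, lstm_param_count(input_size=64, hidden_size=hidden, num_layers=2))
```

Because the recurrent weight matrices grow with the square of `hidden_size`, doubling it roughly quadruples the LSTM's parameter count, which is consistent with the jump described above.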
|
||||
MIT
|
||||
|
@ -1,10 +1,11 @@
|
||||
numpy>=1.21.0
|
||||
numpy>=1.20.0
|
||||
pandas>=1.3.0
|
||||
matplotlib>=3.4.0
|
||||
torch>=1.9.0
|
||||
python-dotenv>=0.19.0
|
||||
scikit-learn>=0.24.0
|
||||
ccxt>=2.0.0
|
||||
python-dotenv>=0.19.0
|
||||
websockets>=10.0
|
||||
tensorboard>=2.6.0
|
||||
scikit-learn
|
||||
mplfinance
|
||||
tensorboard>=2.7.0
|
||||
mexc-api>=1.0.0
|
||||
asyncio>=3.4.3
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
crypto/gogo2/test_mexc_api.py
Normal file
@ -0,0 +1,23 @@
from mexc_api.spot import Spot

def test_mexc_api():
    try:
        # Initialize client with empty API keys for public endpoints
        client = Spot("", "")

        # Test server time endpoint
        server_time = client.market.server_time()
        print(f"Server time: {server_time}")

        # Test ticker price endpoint
        ticker = client.market.ticker_price("ETHUSDT")
        print(f"ETH/USDT price: {ticker}")

        print("MEXC API is working correctly!")
        return True
    except Exception as e:
        print(f"Error testing MEXC API: {e}")
        return False

if __name__ == "__main__":
    test_mexc_api()
34
crypto/gogo2/test_trading_client.py
Normal file
@ -0,0 +1,34 @@
import asyncio
import os
from dotenv import load_dotenv
from mexc_trading import MexcTradingClient

# Load environment variables
load_dotenv()

async def test_trading_client():
    """Test the MexcTradingClient functionality"""
    print("Initializing MexcTradingClient...")
    client = MexcTradingClient(symbol="ETH/USDT")

    # Test getting market price
    print("Testing get_market_price...")
    price = await client.get_market_price()
    print(f"Current ETH/USDT price: {price}")

    # If API keys are provided, test account balance
    if os.getenv('MEXC_API_KEY') and os.getenv('MEXC_SECRET_KEY'):
        print("Testing fetch_account_balance...")
        balance = await client.fetch_account_balance()
        print(f"Account balance: {balance} USDT")

        print("Testing fetch_open_positions...")
        positions = await client.fetch_open_positions()
        print(f"Open positions: {positions}")
    else:
        print("No API keys provided. Skipping private endpoint tests.")

    print("MexcTradingClient test completed!")

if __name__ == "__main__":
    asyncio.run(test_trading_client())
File diff suppressed because it is too large
Binary file not shown.
Before Width: | Height: | Size: 170 KiB After Width: | Height: | Size: 307 KiB |
Binary file not shown.
Before Width: | Height: | Size: 86 KiB After Width: | Height: | Size: 84 KiB |