fixes, lots of new ideas

Dobromir Popov 2025-03-12 00:56:32 +02:00
parent d9f1bac11c
commit ad559d8c61
21 changed files with 13025 additions and 179 deletions

.gitignore

@ -38,4 +38,6 @@ crypto/gogo2/trading_bot.log
 crypto/gogo2/checkpoints/trading_agent_episode_*.pt
 *trading_agent_continuous_*.pt
 *trading_agent_episode_*.pt
-crypto/gogo2/models/trading_agent_continuous_150.pt
+crypto/gogo2/models/trading_agent_continuous_*.pt
+crypto/gogo2/visualizations/training_episode_*.png
+crypto/gogo2/checkpoints/trading_agent_episode_*.pt

crypto/gogo2/_model.md Normal file

@ -0,0 +1,67 @@
# Neural Network Architecture Analysis for Trading Bot
## Overview
This document provides a comprehensive analysis of the neural network architecture used in our trading bot system. The system consists of two main neural network components:
1. **Price Prediction Model** - Forecasts future price movements and extrema points
2. **DQN (Deep Q-Network)** - Makes trading decisions based on state representations
## 1. Price Prediction Model
### Architecture
```
PricePredictionModel(nn.Module)
├── Input Layer: [batch_size, seq_len, 2] (price, volume)
├── LSTM Layers: 2 stacked layers with hidden_size=128
├── Attention Mechanism: Self-attention with linear projections
├── Linear Layer 1: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 2: hidden_size → output_size (5 future prices)
└── Output: [batch_size, output_size]
```
### Data Flow
**Inputs:**
- `price_history`: Sequence of historical prices [batch_size, seq_len]
- `volume_history`: Sequence of historical volumes [batch_size, seq_len]
**Preprocessing:**
- Normalization using MinMaxScaler (0-1 range)
- Reshaping to [batch_size, seq_len, 2] (price and volume features)
**Forward Pass:**
1. Input data passes through LSTM layers
2. Self-attention mechanism applied to LSTM outputs
3. Linear layers process the attended features
4. Output represents predicted prices for next 5 candles
**Outputs:**
- `predicted_prices`: Array of 5 future price predictions
- `predicted_extrema`: Binary indicators for potential price extrema points
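The preprocessing steps listed above can be sketched in plain Python. This is only an illustration under stated assumptions: `min_max_scale` and `build_input` are hypothetical helper names that do not appear in the bot, and the scaling is written by hand instead of using MinMaxScaler so that the example has no dependencies.

```python
def min_max_scale(values):
    """Scale a sequence to the 0-1 range (a constant sequence maps to 0.0)."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


def build_input(price_history, volume_history):
    """Return a [seq_len, 2] nested list of (price, volume) features."""
    if len(price_history) != len(volume_history):
        raise ValueError("price and volume histories must have equal length")
    prices = min_max_scale(price_history)
    volumes = min_max_scale(volume_history)
    return [[p, v] for p, v in zip(prices, volumes)]


# One sequence of three candles; a batch would be a list of such sequences.
sample = build_input([100.0, 110.0, 105.0], [10.0, 30.0, 20.0])
```

Wrapping several such sequences in an outer list would give the [batch_size, seq_len, 2] layout the document describes.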
## 2. DQN (Deep Q-Network)
### Architecture
```
DQN(nn.Module)
├── Input Layer: [batch_size, state_size]
├── Linear Layer 1: state_size → hidden_size (384)
├── ReLU Activation
├── LSTM Layers: 2 stacked layers with hidden_size=384
├── Multi-Head Attention: 4 attention heads
├── Linear Layer 2: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 3: hidden_size → action_size (4)
└── Output: [batch_size, action_size] (Q-values for each action)
```
### Data Flow
**Inputs:**
- `state`: Current market state representation [batch_size, state_size]
  - Price features (normalized prices, returns, volatility)
  - Technical indicators (RSI, MACD, Stochastic,


@ -4,6 +4,12 @@ ensure we use GPU if available to train faster. during training we need to have
+Do we call the trading API properly when in live mode? In live mode, also validate the current balance. Also ensure trades are actually executed after placing the orders by checking the open orders.
+Our trading data chart (in TensorBoard) is not displayed properly: the candles seem to be drawn multiple times but shifted in time. We also do not correctly show the buy/sell events on the time axis, and we do not show the predicted price on the chart.
 2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
 C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
@ -22,3 +28,13 @@ Backend tkagg is interactive backend. Turning interactive mode on.
 2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
+----------------
+Do we train the price prediction model by taking old, known candles, masking the latest ones so the model has to guess the next candle, and then backpropagating against the already known next candle, the way a transformer (GPT-2) would? Or do we use RL for that as well?
+It seems the model is not learning much. We keep hovering around the starting balance even after some time training in continuous mode.
+It seems we may need another NN model further down the loop just to predict the price extrema.
+We may have to include a mechanism that calculates the price extrema retrospectively and use it to bootstrap pre-training of the model.
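The retrospective extrema idea in the last note could look roughly like the sketch below. It is an assumption-laden illustration, not code from this commit: `label_extrema` and its `window` parameter are hypothetical, and "extremum" is taken to mean a strict local high or low among the neighbouring candles on each side.

```python
def label_extrema(prices, window=2):
    """Mark local highs (+1), local lows (-1) and everything else (0).

    A point counts as an extremum only if it is the strict max/min of the
    `window` candles on each side, so it can be labelled only after the fact.
    """
    labels = [0] * len(prices)
    for i in range(window, len(prices) - window):
        neighbours = prices[i - window:i] + prices[i + 1:i + window + 1]
        if prices[i] > max(neighbours):
            labels[i] = 1
        elif prices[i] < min(neighbours):
            labels[i] = -1
    return labels


closes = [10, 11, 13, 12, 11, 9, 10, 12, 11]
labels = label_extrema(closes)
```

Labels produced this way on historical candles could serve as supervised targets for an extrema head before any RL training starts. The last `window` candles can never be labelled, which is why this works only retrospectively.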

crypto/gogo2/archive.py Normal file

@ -0,0 +1,362 @@
import logging
import traceback
import ccxt
# Assumption: the original module defines `logger`; a module-level logger stands in for it here
logger = logging.getLogger(__name__)
class MexcTradingClient:
    def __init__(self, api_key, secret_key, symbol="ETH/USDT", leverage=50):
        self.client = ccxt.mexc({
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True,
        })
        self.symbol = symbol
        self.leverage = leverage
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.trades = []
    def initialize_mexc_client(self, api_key, api_secret):
        """Initialize the MEXC API client"""
        try:
            from mexc_sdk import Spot
            self.mexc_client = Spot(api_key=api_key, api_secret=api_secret)
            # Test connection
            self.mexc_client.ping()
            logger.info("MEXC API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize MEXC API client: {e}")
            logger.error(traceback.format_exc())
            return False
    async def fetch_account_balance(self):
        """Fetch actual account balance from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, use simulated balance
            return self.balance
        try:
            account_info = self.mexc_client.accountInfo()
            if 'balances' in account_info:
                # Find USDT balance
                for asset in account_info['balances']:
                    if asset['asset'] == 'USDT':
                        return float(asset['free'])
            logger.warning("Could not find USDT balance, using current simulated balance")
            return self.balance
        except Exception as e:
            logger.error(f"Error fetching account balance: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated balance in case of API error
            return self.balance
    async def fetch_open_positions(self):
        """Fetch actual open positions from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, return current simulated position
            return [{
                'symbol': 'ETH/USDT',
                'positionSide': 'LONG' if self.position == 'long' else 'SHORT' if self.position == 'short' else 'NONE',
                'positionAmt': self.position_size / self.current_price if self.position != 'flat' else 0,
                'entryPrice': self.entry_price,
                'unrealizedProfit': self.calculate_unrealized_pnl()
            }] if self.position != 'flat' else []
        try:
            # Fetch open positions
            positions = self.mexc_client.openOrders('ETH/USDT')
            return positions
        except Exception as e:
            logger.error(f"Error fetching open positions: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated positions in case of API error
            return []
    def calculate_unrealized_pnl(self):
        """Calculate unrealized PnL (in dollars) for the current position"""
        if self.position == 'flat':
            return 0.0
        if self.position == 'long':
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
        else: # short
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
        # Apply leverage
        pnl_percent *= self.leverage
        # position_size is a dollar amount, so scale it directly;
        # this matches the pnl_dollar computation in close_position
        return self.position_size * pnl_percent / 100
    async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
        """Open a new position using MEXC API in live mode, or simulate in demo mode"""
        if self.demo or not self.mexc_client:
            # In demo mode, simulate opening a position
            self.position = position_type
            self.position_size = size
            self.entry_price = entry_price
            self.entry_index = self.current_step
            self.stop_loss = stop_loss
            self.take_profit = take_profit
            logger.info(f"DEMO: Opened {position_type.upper()} position at {entry_price} | " +
                        f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
            return True
        try:
            # In live mode, place actual orders via API
            symbol = "ETHUSDT" # Format required by MEXC
            side = "BUY" if position_type == 'long' else "SELL"
            # Calculate quantity based on size and price
            quantity = size / entry_price
            # Place main order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "leverage": self.leverage,
                    "newOrderRespType": "FULL"
                }
            )
            # Check if order executed
            if order_result.get('status') == 'FILLED':
                actual_entry_price = float(order_result.get('price', entry_price))
                # Place stop loss order
                sl_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="STOP_LOSS",
                    quantity=quantity,
                    options={
                        "stopPrice": stop_loss,
                        "newOrderRespType": "ACK"
                    }
                )
                # Place take profit order
                tp_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="TAKE_PROFIT",
                    quantity=quantity,
                    options={
                        "stopPrice": take_profit,
                        "newOrderRespType": "ACK"
                    }
                )
                # Update local state
                self.position = position_type
                self.position_size = size
                self.entry_price = actual_entry_price
                self.entry_index = self.current_step
                self.stop_loss = stop_loss
                self.take_profit = take_profit
                # Track orders
                self.open_orders.extend([sl_order, tp_order])
                self.order_history.append(order_result)
                logger.info(f"LIVE: Opened {position_type.upper()} position at {actual_entry_price} | " +
                            f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
                return True
            else:
                logger.error(f"Failed to execute order: {order_result}")
                return False
        except Exception as e:
            logger.error(f"Error opening position: {e}")
            logger.error(traceback.format_exc())
            return False
    async def close_position(self, reason="manual_close"):
        """Close the current position using MEXC API in live mode, or simulate in demo mode"""
        if self.position == 'flat':
            logger.info("No position to close")
            return False
        if self.demo or not self.mexc_client:
            # In demo mode, simulate closing a position
            position_type = self.position
            entry_price = self.entry_price
            exit_price = self.current_price
            position_size = self.position_size
            # Calculate PnL
            if position_type == 'long':
                pnl_percent = (exit_price - entry_price) / entry_price * 100
            else: # short
                pnl_percent = (entry_price - exit_price) / entry_price * 100
            # Apply leverage
            pnl_percent *= self.leverage
            # Calculate actual PnL
            pnl_dollar = position_size * pnl_percent / 100
            # Apply fees
            pnl_dollar -= self.calculate_fees(position_size)
            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar
            self.episode_pnl += pnl_dollar
            # Update max drawdown
            if self.balance > self.peak_balance:
                self.peak_balance = self.balance
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)
            # Record trade
            self.trades.append({
                'type': position_type,
                'entry': entry_price,
                'exit': exit_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,
                'market_direction': self.get_market_direction(),
                'reason': reason,
                'leverage': self.leverage
            })
            # Update win/loss count
            if pnl_dollar > 0:
                self.win_count += 1
            else:
                self.loss_count += 1
            logger.info(f"DEMO: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
            # Reset position
            self.position = 'flat'
            self.entry_price = 0
            self.entry_index = 0
            self.position_size = 0
            self.stop_loss = 0
            self.take_profit = 0
            return True
        try:
            # In live mode, close position via API
            symbol = "ETHUSDT"
            position_info = await self.fetch_open_positions()
            if not position_info:
                logger.warning("No open positions found to close")
                self.position = 'flat'
                return False
            # First, cancel any existing stop loss/take profit orders
            try:
                self.mexc_client.cancelOpenOrders(symbol)
            except Exception as e:
                logger.warning(f"Error canceling open orders: {e}")
            # Close the position with a market order
            position_type = self.position
            side = "SELL" if position_type == 'long' else "BUY"
            quantity = self.position_size / self.current_price
            # Execute order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "newOrderRespType": "FULL"
                }
            )
            # Check if order executed
            if order_result.get('status') == 'FILLED':
                exit_price = float(order_result.get('price', self.current_price))
                entry_price = self.entry_price
                position_size = self.position_size
                # Calculate PnL
                if position_type == 'long':
                    pnl_percent = (exit_price - entry_price) / entry_price * 100
                else: # short
                    pnl_percent = (entry_price - exit_price) / entry_price * 100
                # Apply leverage
                pnl_percent *= self.leverage
                # Calculate actual PnL
                pnl_dollar = position_size * pnl_percent / 100
                # Apply fees
                pnl_dollar -= self.calculate_fees(position_size)
                # Update balance from API
                self.balance = await self.fetch_account_balance()
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar
                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)
                # Record trade
                self.trades.append({
                    'type': position_type,
                    'entry': entry_price,
                    'exit': exit_price,
                    'entry_time': self.data[self.entry_index]['timestamp'],
                    'exit_time': self.data[self.current_step]['timestamp'],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': reason,
                    'leverage': self.leverage,
                    'order_id': order_result.get('orderId')
                })
                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
                # Track order history
                self.order_history.append(order_result)
                logger.info(f"LIVE: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0
                return True
            else:
                logger.error(f"Failed to close position: {order_result}")
                return False
        except Exception as e:
            logger.error(f"Error closing position: {e}")
            logger.error(traceback.format_exc())
            return False
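The PnL and drawdown arithmetic that `close_position` repeats in its demo and live branches can be isolated as two small pure functions. This is a sketch, not part of the commit: `leveraged_pnl` and `update_drawdown` are hypothetical names, and the fee is passed in as a plain number because `calculate_fees` is not shown in this file.

```python
def leveraged_pnl(position_type, entry_price, exit_price, position_size, leverage, fee=0.0):
    """Return (pnl_percent, pnl_dollar) for a closed position.

    position_size is a dollar amount, as in the listing above.
    """
    if position_type == 'long':
        pnl_percent = (exit_price - entry_price) / entry_price * 100
    else:  # short
        pnl_percent = (entry_price - exit_price) / entry_price * 100
    pnl_percent *= leverage
    pnl_dollar = position_size * pnl_percent / 100 - fee
    return pnl_percent, pnl_dollar


def update_drawdown(balance, peak_balance, max_drawdown):
    """Return the updated (peak_balance, max_drawdown) after a balance change."""
    peak_balance = max(peak_balance, balance)
    drawdown = (peak_balance - balance) / peak_balance
    return peak_balance, max(max_drawdown, drawdown)


# A $100 long at 50x leverage on a 0.5% move, then a drop from a $100 peak to $90.
pct, usd = leveraged_pnl('long', 2000.0, 2010.0, 100.0, 50)
peak, dd = update_drawdown(balance=90.0, peak_balance=100.0, max_drawdown=0.0)
```

Factoring the arithmetic out this way would let the demo and live paths share one implementation and make it testable without an exchange connection.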


@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}
+{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 0, "timestamp": "2025-03-12T00:23:19.125190"}


@ -0,0 +1,207 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class PricePredictionModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=128, num_layers=2, output_dim=5):
        super(PricePredictionModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        # Self-attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        # Fully connected layer for price prediction
        self.price_fc = nn.Linear(hidden_dim, output_dim)
        # Fully connected layer for extrema prediction (high and low points)
        self.extrema_fc = nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
    def forward(self, x):
        # x shape: (batch_size, seq_len, input_dim)
        # LSTM forward pass
        lstm_out, _ = self.lstm(x) # lstm_out: (batch_size, seq_len, hidden_dim)
        # Self-attention
        attn_output, _ = self.attention(lstm_out, lstm_out, lstm_out)
        # Price prediction
        price_pred = self.price_fc(attn_output[:, -1, :]) # Use the last time step
        # Extrema prediction
        extrema_logits = self.extrema_fc(attn_output[:, -1, :])
        return price_pred, extrema_logits
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(DQN, self).__init__()
        # Feature extraction layers
        self.feature_extraction = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )
        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Transformer for temporal dependencies
        encoder_layers = TransformerEncoderLayer(d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim*4, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        # LSTM for sequential decision making
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # Final layers
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    def forward(self, state, hidden=None):
        # Extract features
        features = self.feature_extraction(state)
        features = features.unsqueeze(1) # Add sequence dimension for transformer/LSTM
        # Transformer processing
        transformer_out = self.transformer(features)
        # LSTM processing
        lstm_out, lstm_hidden = self.lstm(transformer_out)
        # Dueling architecture
        advantage = self.advantage_stream(features.squeeze(1))
        value = self.value_stream(features.squeeze(1))
        # Combine transformer, LSTM and dueling outputs
        combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
        q_values = self.final_layers(combined)
        # Dueling Q-value computation
        # NOTE: this assignment overwrites the final_layers output computed above,
        # so the transformer/LSTM branch does not affect the returned Q-values
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values, lstm_hidden
def count_parameters(model):
    total_params = 0
    layer_params = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            param_count = param.numel()
            total_params += param_count
            layer_params[name] = (param_count, param.shape)
    return total_params, layer_params
def main():
    # Initialize the Price Prediction Model
    price_model = PricePredictionModel()
    price_total_params, price_layer_params = count_parameters(price_model)
    print(f"Price Prediction Model parameters: {price_total_params:,}")
    print("\nPrice Prediction Model Layers:")
    for name, (count, shape) in price_layer_params.items():
        print(f"{name}: {count:,} (shape: {shape})")
    # Initialize the DQN Model with typical dimensions
    state_dim = 50 # Typical state dimension for the trading bot
    action_dim = 3 # Typical action dimension (buy, sell, hold)
    dqn_model = DQN(state_dim=state_dim, action_dim=action_dim)
    dqn_total_params, dqn_layer_params = count_parameters(dqn_model)
    # Count parameters by category
    feature_extraction_params = sum(count for name, (count, _) in dqn_layer_params.items() if "feature_extraction" in name)
    advantage_value_params = sum(count for name, (count, _) in dqn_layer_params.items() if "advantage_stream" in name or "value_stream" in name)
    transformer_params = sum(count for name, (count, _) in dqn_layer_params.items() if "transformer" in name)
    lstm_params = sum(count for name, (count, _) in dqn_layer_params.items() if "lstm" in name and "transformer" not in name)
    final_layers_params = sum(count for name, (count, _) in dqn_layer_params.items() if "final_layers" in name)
    print(f"\nDQN Model parameters: {dqn_total_params:,}")
    # Create sets to track which parameters we've printed
    printed_params = set()
    # Print DQN layers in groups to avoid output truncation
    print(f"\nDQN Model Layers (Feature Extraction): {feature_extraction_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "feature_extraction" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)
    print(f"\nDQN Model Layers (Advantage & Value Streams): {advantage_value_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "advantage_stream" in name or "value_stream" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)
    print(f"\nDQN Model Layers (Transformer): {transformer_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "transformer" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)
    print(f"\nDQN Model Layers (LSTM): {lstm_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "lstm" in name and "transformer" not in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)
    print(f"\nDQN Model Layers (Final Layers): {final_layers_params:,} parameters")
    for name, (count, shape) in dqn_layer_params.items():
        if "final_layers" in name:
            print(f"{name}: {count:,} (shape: {shape})")
            printed_params.add(name)
    # Print any remaining parameters that weren't caught by the categories above
    remaining_params = set(dqn_layer_params.keys()) - printed_params
    if remaining_params:
        remaining_params_count = sum(dqn_layer_params[name][0] for name in remaining_params)
        print(f"\nDQN Model Layers (Other): {remaining_params_count:,} parameters")
        for name in remaining_params:
            count, shape = dqn_layer_params[name]
            print(f"{name}: {count:,} (shape: {shape})")
    # Total parameters across both models
    print(f"\nTotal parameters (both models): {price_total_params + dqn_total_params:,}")
    # Print summary of parameter distribution
    print("\nParameter Distribution Summary:")
    print(f"Price Prediction Model: {price_total_params:,} parameters ({price_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
    print(f"DQN Model: {dqn_total_params:,} parameters ({dqn_total_params/(price_total_params + dqn_total_params)*100:.1f}%)")
    print("\nDQN Model Breakdown:")
    print(f"- Feature Extraction: {feature_extraction_params:,} parameters ({feature_extraction_params/dqn_total_params*100:.1f}%)")
    print(f"- Advantage & Value Streams: {advantage_value_params:,} parameters ({advantage_value_params/dqn_total_params*100:.1f}%)")
    print(f"- Transformer: {transformer_params:,} parameters ({transformer_params/dqn_total_params*100:.1f}%)")
    print(f"- LSTM: {lstm_params:,} parameters ({lstm_params/dqn_total_params*100:.1f}%)")
    print(f"- Final Layers: {final_layers_params:,} parameters ({final_layers_params/dqn_total_params*100:.1f}%)")
    # Verify that all parameters are accounted for
    total_by_category = feature_extraction_params + advantage_value_params + transformer_params + lstm_params + final_layers_params
    if remaining_params:
        total_by_category += remaining_params_count
    print(f"\nSum of all categories: {total_by_category:,} parameters")
    print(f"Difference from total: {dqn_total_params - total_by_category:,} parameters")
if __name__ == "__main__":
    main()
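The totals printed for `PricePredictionModel` can be cross-checked by hand without instantiating the model. The sketch below is an assumption-based calculation, not part of the commit: the helper names are hypothetical, and the formulas assume PyTorch's default layouts (a unidirectional LSTM with two bias vectors per layer, and `nn.MultiheadAttention` with equal query/key/value dimensions and biases enabled).

```python
def lstm_param_count(input_dim, hidden_dim, num_layers):
    """Parameters of a unidirectional LSTM with biases."""
    total = 0
    for layer in range(num_layers):
        in_dim = input_dim if layer == 0 else hidden_dim
        # 4 gates, each with input weights, recurrent weights and two bias vectors
        total += 4 * (in_dim * hidden_dim + hidden_dim * hidden_dim + 2 * hidden_dim)
    return total


def linear_param_count(in_features, out_features):
    """Weight matrix plus bias vector of a fully connected layer."""
    return in_features * out_features + out_features


def attention_param_count(embed_dim):
    """Packed q/k/v input projection plus the output projection."""
    in_proj = 3 * embed_dim * embed_dim + 3 * embed_dim
    out_proj = linear_param_count(embed_dim, embed_dim)
    return in_proj + out_proj


# Defaults from the listing: input_dim=2, hidden_dim=128, num_layers=2, output_dim=5
price_model_params = (
    lstm_param_count(2, 128, 2)
    + attention_param_count(128)
    + linear_param_count(128, 5)    # price_fc
    + linear_param_count(128, 10)   # extrema_fc
)
```

If the layer layouts are as assumed, this figure should agree with what `count_parameters(PricePredictionModel())` reports, and a mismatch would point at a layer the hand calculation missed.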


@ -0,0 +1,42 @@
def fix_indentation():
    with open('main.py', 'r') as f:
        lines = f.readlines()
    # Fix indentation for the problematic sections
    fixed_lines = []
    # Find the try block that starts at line 1693
    try_start_line = 1693
    try_block_found = False
    in_try_block = False
    for i, line in enumerate(lines):
        # Check if we're at the try statement
        if i+1 == try_start_line and 'try:' in line:
            try_block_found = True
            in_try_block = True
            fixed_lines.append(line)
        # Fix the indentation of the experiences line
        elif i+1 == 1695 and line.strip().startswith('experiences = self.memory.sample(BATCH_SIZE)'):
            # Add proper indentation (4 spaces)
            fixed_lines.append(' ' + line.lstrip())
        # Check if we're at the end of the try block without an except
        elif try_block_found and in_try_block and i+1 > try_start_line and line.strip() and not line.startswith(' '):
            # We've reached the end of the try block without an except, add one
            fixed_lines.append(' except Exception as e:\n')
            fixed_lines.append(' logger.error(f"Error during learning: {e}")\n')
            fixed_lines.append(' logger.error(f"Traceback: {traceback.format_exc()}")\n')
            fixed_lines.append(' return None\n\n')
            in_try_block = False
            fixed_lines.append(line)
        else:
            fixed_lines.append(line)
    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.writelines(fixed_lines)
    print("Indentation fixed!")
if __name__ == "__main__":
    fix_indentation()


@ -0,0 +1,22 @@
import re
def fix_try_blocks():
    with open('main.py', 'r') as f:
        content = f.read()
    # Find all try blocks without except or finally
    pattern = r'(\s+)try:\s*\n((?:\1\s+.*\n)+?)(?!\1\s*except|\1\s*finally)'
    # Replace with try-except blocks
    fixed_content = re.sub(pattern,
                           r'\1try:\n\2\1except Exception as e:\n\1 logger.error(f"Error: {e}")\n\1 logger.error(f"Traceback: {traceback.format_exc()}")\n\1 return None\n\n',
                           content)
    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.write(fixed_content)
    print("Try blocks fixed!")
if __name__ == "__main__":
    fix_try_blocks()
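Both repair scripts rewrite `main.py` in place, so a regex or line-number mismatch can leave the file worse than before. One possible safeguard, sketched here as an assumption rather than something the commit does, is to parse the patched text with the standard library's `ast` module and write it back only if it is valid Python. `is_valid_python` is a hypothetical helper name.

```python
import ast


def is_valid_python(source):
    """Return True if `source` parses as Python, False on a syntax error."""
    try:
        ast.parse(source)
    except SyntaxError:  # IndentationError is a subclass of SyntaxError
        return False
    return True


# A try block with no except/finally, like the ones the scripts target.
broken = "def f():\n    try:\n        x = 1\n    return x\n"
fixed = (
    "def f():\n"
    "    try:\n"
    "        x = 1\n"
    "    except Exception:\n"
    "        return None\n"
    "    return x\n"
)
```

A repair script could call `is_valid_python(fixed_content)` before opening `main.py` for writing and abort otherwise, which keeps a failed fix from destroying the original file.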


@ -1151,10 +1151,10 @@ class TradingEnvironment:
 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
     self.win_count += 1
 else:
-    reward = -1.0 # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
     self.loss_count += 1
 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@ -1238,10 +1238,15 @@ class TradingEnvironment:
 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
     self.win_count += 1
+    # Extra reward for closing at a predicted high
+    if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
+        reward += 1.0
+        logger.info("Closing long at predicted high - additional reward")
 else:
-    reward = -1.0 # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
     self.loss_count += 1
 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@ -1296,15 +1301,15 @@ class TradingEnvironment:
 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
     self.win_count += 1
     # Extra reward for closing at a predicted high
     if hasattr(self, 'has_predicted_high') and self.has_predicted_high:
-        reward += 0.5
+        reward += 1.0
         logger.info("Closing long at predicted high - additional reward")
 else:
-    reward = -1.0 # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
     self.loss_count += 1
 logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@ -1354,15 +1359,15 @@ class TradingEnvironment:
 # Reward based on PnL
 if pnl_dollar > 0:
-    reward = 1.0 + pnl_dollar / 10 # Positive reward for profit
+    reward = 2.0 + pnl_dollar * 0.5 # Increased positive reward for profit
     self.win_count += 1
     # Extra reward for closing at a predicted low
     if hasattr(self, 'has_predicted_low') and self.has_predicted_low:
-        reward += 0.5
+        reward += 1.0
         logger.info("Closing short at predicted low - additional reward")
 else:
-    reward = -1.0 # Negative reward for loss
+    reward = -2.0 - abs(pnl_dollar) * 0.3 # Stronger negative reward for loss
     self.loss_count += 1
 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
@ -1378,9 +1383,9 @@ class TradingEnvironment:
 # Add reward based on direct PnL change
 balance_change = self.balance - prev_balance
 if balance_change > 0:
-    reward += balance_change * 0.5 # Positive reward for making money
+    reward += balance_change * 1.0 # Increased positive reward for making money
 else:
-    reward += balance_change * 1.0 # Stronger negative reward for losing money
+    reward += balance_change * 2.0 # Stronger negative reward for losing money
 # Add reward for predicted price movement alignment
 if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
@ -1611,9 +1616,15 @@ class TradingEnvironment:
 def initialize_price_predictor(self, device="cpu"):
     """Initialize the price prediction model"""
-    self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
-    self.price_predictor.to(device)
-    self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
+    # Only create a new model if one doesn't already exist
+    if not hasattr(self, 'price_predictor') or self.price_predictor is None:
+        self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
+        self.price_predictor.to(device)
+        self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
+    else:
+        # If model exists, just ensure it's on the right device
+        self.price_predictor.to(device)
     self.predicted_prices = np.array([])
     self.predicted_extrema = np.array([])
     self.extrema_threshold = 0.7 # Threshold for extrema prediction confidence
@ -1766,16 +1777,16 @@ class TradingEnvironment:
         return fee
 # Ensure GPU usage if available
-def get_device(device_preference='gpu'):
+def get_device(device_preference='auto'):
     """Get the device to use (GPU or CPU) based on preference and availability"""
-    if device_preference.lower() == 'gpu' and torch.cuda.is_available():
+    if device_preference.lower() in ['gpu', 'auto'] and torch.cuda.is_available():
         device = torch.device("cuda")
         # Set default tensor type to float32 for CUDA
         torch.set_default_tensor_type(torch.FloatTensor)
         logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
     else:
         device = torch.device("cpu")
-        if device_preference.lower() == 'gpu':
+        if device_preference.lower() in ['gpu', 'auto']:
             logger.info("GPU requested but not available, using CPU instead")
         else:
             logger.info("Using CPU as requested")
@ -1952,7 +1963,7 @@ class Agent:
# Use mixed precision for forward/backward passes # Use mixed precision for forward/backward passes
if self.device.type == "cuda": if self.device.type == "cuda":
with amp.autocast(): with amp.autocast(device_type='cuda'):
# Compute Q values # Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
@ -2943,10 +2954,12 @@ async def main():
import traceback import traceback
parser = argparse.ArgumentParser(description='Run the trading bot') parser = argparse.ArgumentParser(description='Run the trading bot')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'], help='Mode to run the bot in') parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live', 'continuous'], help='Mode to run the bot in')
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train for') parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train for')
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)') parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
parser.add_argument('--device', type=str, default='auto', choices=['cpu', 'gpu', 'auto'], help='Device to use for training') parser.add_argument('--device', type=str, default='auto', choices=['cpu', 'gpu', 'auto'], help='Device to use for training')
parser.add_argument('--refresh-data', '--refresh_data', dest='refresh_data', action='store_true', help='Refresh data at the start of each episode')
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for data (e.g., 1s, 1m, 5m, 15m, 1h)')
args = parser.parse_args() args = parser.parse_args()
# Set device # Set device
@ -2995,6 +3008,247 @@ async def main():
results = evaluate_agent(agent, env, num_episodes=10) results = evaluate_agent(agent, env, num_episodes=10)
logger.info(f"Evaluation results: {results}") logger.info(f"Evaluation results: {results}")
elif args.mode == 'continuous':
# Continuous training mode - train indefinitely with data refreshing
logger.info("Starting continuous training mode...")
# Set refresh_data to True for continuous mode
args.refresh_data = True
# Create directories for continuous models
os.makedirs("models", exist_ok=True)
# Track best PnL for model selection
best_pnl = float('-inf')
best_pnl_model_path = "models/trading_agent_best_pnl.pt"
# Load the best PnL model if it exists
if os.path.exists(best_pnl_model_path):
logger.info(f"Loading best PnL model: {best_pnl_model_path}")
agent.load(best_pnl_model_path)
# Try to load best PnL value from checkpoint file
checkpoint_info_path = "checkpoints/best_metrics.json"
if os.path.exists(checkpoint_info_path):
with open(checkpoint_info_path, 'r') as f:
best_metrics = json.load(f)
best_pnl = best_metrics.get('best_pnl', best_pnl)
logger.info(f"Loaded best PnL from checkpoint: ${best_pnl:.2f}")
# Initialize episode counter
episode = 0
# Get timeframe from args
timeframe = args.timeframe
logger.info(f"Using timeframe: {timeframe}")
# Initialize TensorBoard writer
from torch.utils.tensorboard import SummaryWriter
tensorboard_dir = f"runs/continuous_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(tensorboard_dir)
logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
# Attach writer to agent
agent.writer = writer
# Initialize stats dictionary for plotting
stats = {
'episode_rewards': [],
'episode_profits': [],
'win_rates': [],
'trade_counts': [],
'prediction_accuracies': []
}
# Train continuously
try:
while True:
logger.info(f"Continuous training - Episode {episode}")
# Refresh data from exchange with the specified timeframe
logger.info(f"Refreshing market data with timeframe {timeframe}...")
await env.fetch_new_data(exchange, "ETH/USDT", timeframe, 100)
# Reset environment
state = env.reset()
# Initialize price predictor if not already initialized
if not hasattr(env, 'price_predictor') or env.price_predictor is None:
logger.info("Initializing price predictor...")
env.initialize_price_predictor(device=agent.device)
# Initialize episode variables
episode_reward = 0
done = False
# Train price predictor
prediction_loss, extrema_loss = env.train_price_predictor()
# Update price predictions
env.update_price_predictions()
# Training loop for this episode
while not done:
# Select action
action = agent.select_action(state)
# Take action
next_state, reward, done = env.step(action)
# Store experience
agent.memory.push(state, action, reward, next_state, done)
# Learn from experience
loss = agent.learn()
# Update state and reward
state = next_state
episode_reward += reward
# Calculate win rate
total_trades = env.win_count + env.loss_count
win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
# Calculate prediction accuracy
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
# Compare predictions with actual prices
actual_prices = env.features['price'][-len(env.predicted_prices):]
prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices
prediction_accuracy = 100 * (1 - np.mean(prediction_errors))
else:
prediction_accuracy = 0
# Update stats
stats['episode_rewards'].append(episode_reward)
stats['episode_profits'].append(env.episode_pnl)
stats['win_rates'].append(win_rate)
stats['trade_counts'].append(total_trades)
stats['prediction_accuracies'].append(prediction_accuracy)
# Log to TensorBoard
writer.add_scalar('Reward/continuous', episode_reward, episode)
writer.add_scalar('Balance/continuous', env.balance, episode)
writer.add_scalar('WinRate/continuous', win_rate, episode)
writer.add_scalar('PnL/episode', env.episode_pnl, episode)
writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
writer.add_scalar('PredictionLoss', prediction_loss, episode)
writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
# Log OHLCV data to TensorBoard every 5 episodes
if episode % 5 == 0:
# Create a DataFrame from the environment's data
df_ohlcv = pd.DataFrame([{
'timestamp': candle['timestamp'],
'open': candle['open'],
'high': candle['high'],
'low': candle['low'],
'close': candle['close'],
'volume': candle['volume']
} for candle in env.data[-100:]]) # Use last 100 candles
# Convert timestamp to datetime
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
df_ohlcv.set_index('timestamp', inplace=True)
# Extract buy/sell signals from trades
buy_signals = []
sell_signals = []
if hasattr(env, 'trades') and env.trades:
for trade in env.trades:
if 'entry_time' in trade and 'entry' in trade:
if trade['type'] == 'long':
# Buy signal
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
buy_signals.append((entry_time, trade['entry']))
# Sell signal if closed
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
sell_signals.append((exit_time, trade['exit']))
elif trade['type'] == 'short':
# Sell short signal
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
sell_signals.append((entry_time, trade['entry']))
# Buy to cover signal if closed
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
buy_signals.append((exit_time, trade['exit']))
# Log to TensorBoard
log_ohlcv_to_tensorboard(
writer,
df_ohlcv,
buy_signals,
sell_signals,
episode,
tag_prefix=f"continuous_episode_{episode}"
)
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}")
# Create visualization every 10 episodes
if episode % 10 == 0:
# Create visualization
os.makedirs("visualizations", exist_ok=True)
visualize_training_results(env, agent, episode)
# Save model
model_path = f"models/trading_agent_continuous_{episode}.pt"
agent.save(model_path)
logger.info(f"Saved continuous model: {model_path}")
# Plot training results
plot_training_results(stats)
# Save best PnL model
if env.episode_pnl > best_pnl:
best_pnl = env.episode_pnl
agent.save(best_pnl_model_path)
logger.info(f"New best PnL model saved: ${env.episode_pnl:.2f}")
# Save best metrics to resume training if interrupted
best_metrics = {
'best_pnl': float(best_pnl),
'last_episode': episode,
'timestamp': datetime.datetime.now().isoformat()
}
os.makedirs("checkpoints", exist_ok=True)
with open("checkpoints/best_metrics.json", 'w') as f:
json.dump(best_metrics, f)
# Update target network
agent.update_target_network()
# Increment episode counter
episode += 1
# Sleep briefly to prevent overwhelming the system
# Use shorter sleep for shorter timeframes
if timeframe.endswith('s'):
await asyncio.sleep(0.1) # Very short sleep for second-based timeframes
else:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Continuous training stopped by user")
# Save final model
agent.save("models/trading_agent_continuous_final.pt")
# Close TensorBoard writer
writer.close()
except Exception as e:
logger.error(f"Error in continuous training: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
# Save emergency model
agent.save(f"models/trading_agent_continuous_emergency_{episode}.pt")
# Close TensorBoard writer
writer.close()
elif args.mode == 'evaluate': elif args.mode == 'evaluate':
# Load the best model # Load the best model
agent.load("models/trading_agent_best_pnl.pt") agent.load("models/trading_agent_best_pnl.pt")


@ -0,0 +1,304 @@
import os
import json
import logging
import traceback
import numpy as np
from dotenv import load_dotenv
from mexc_api.spot import Spot
from mexc_api.common.enums import Side, OrderType
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('mexc_trading')
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
class MexcTradingClient:
    """Client for executing trades on MEXC exchange using the official API"""

    def __init__(self, api_key=None, api_secret=None, symbol="ETH/USDT", leverage=1):
        """Initialize the MEXC trading client"""
        self.api_key = api_key or MEXC_API_KEY
        self.api_secret = api_secret or MEXC_SECRET_KEY
        # Ensure API keys are not None
        if not self.api_key or not self.api_secret:
            logger.warning("API keys not provided. Using empty strings for public endpoints only.")
            self.api_key = ""
            self.api_secret = ""
        self.symbol = symbol
        self.formatted_symbol = symbol.replace('/', '')  # MEXC requires no slash
        self.leverage = leverage
        self.client = None
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.open_orders = []
        self.order_history = []
        self.trades = []
        self.win_count = 0
        self.loss_count = 0
        # Initialize the MEXC API client
        self.initialize_client()

    def initialize_client(self):
        """Initialize the MEXC API client using the API"""
        try:
            self.client = Spot(self.api_key, self.api_secret)
            # Test connection
            server_time = self.client.market.server_time()
            logger.info(f"MEXC API client initialized successfully. Server time: {server_time}")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize MEXC API client: {e}")
            logger.error(traceback.format_exc())
            return False

    async def fetch_account_balance(self):
        """Fetch account balance from MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key:
                logger.warning("No API keys provided. Cannot fetch account balance.")
                return 0
            account_info = self.client.account.account_info()
            if 'balances' in account_info:
                # Find USDT balance
                for asset in account_info['balances']:
                    if asset['asset'] == 'USDT':
                        return float(asset['free'])
            logger.warning("Could not find USDT balance")
            return 0
        except Exception as e:
            logger.error(f"Error fetching account balance: {e}")
            return 0

    async def fetch_open_positions(self):
        """Fetch open positions from MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key:
                logger.warning("No API keys provided. Cannot fetch open positions.")
                return []
            # Fetch open orders
            open_orders = self.client.account.open_orders(self.formatted_symbol)
            return open_orders
        except Exception as e:
            logger.error(f"Error fetching open positions: {e}")
            return []

    async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
        """Open a new position using MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key:
                logger.warning("No API keys provided. Cannot open position.")
                return False
            # Calculate quantity based on size and price
            quantity = size / entry_price
            # Round quantity to appropriate precision
            quantity = round(quantity, 4)  # Adjust precision as needed for your asset
            # Determine order side
            side = Side.BUY if position_type == 'long' else Side.SELL
            logger.info(f"Opening {position_type} position: {quantity} {self.symbol} at market price")
            # Place market order
            order_result = self.client.account.new_order(
                self.formatted_symbol,
                side,
                OrderType.MARKET,
                str(quantity)
            )
            logger.info(f"Market order result: {order_result}")
            # Check if order was filled
            if order_result.get('status') in ('FILLED', 'PARTIALLY_FILLED'):
                # Get actual entry price; fall back to the requested price when
                # the response carries no usable (non-zero) price, so that the
                # quantity calculation in close_position cannot divide by zero
                actual_entry_price = float(order_result.get('price') or 0) or entry_price
                # Place stop loss order
                sl_side = Side.SELL if position_type == 'long' else Side.BUY
                sl_order = self.client.account.new_order(
                    self.formatted_symbol,
                    sl_side,
                    OrderType.STOP_LOSS_LIMIT,
                    str(quantity),
                    price=str(stop_loss),
                    stop_price=str(stop_loss),
                    time_in_force="GTC"
                )
                logger.info(f"Stop loss order placed: {sl_order}")
                # Place take profit order
                tp_side = Side.SELL if position_type == 'long' else Side.BUY
                tp_order = self.client.account.new_order(
                    self.formatted_symbol,
                    tp_side,
                    OrderType.TAKE_PROFIT_LIMIT,
                    str(quantity),
                    price=str(take_profit),
                    stop_price=str(take_profit),
                    time_in_force="GTC"
                )
                logger.info(f"Take profit order placed: {tp_order}")
                # Update local state
                self.position = position_type
                self.position_size = size
                self.entry_price = actual_entry_price
                self.stop_loss = stop_loss
                self.take_profit = take_profit
                # Track orders
                self.open_orders.extend([sl_order, tp_order])
                self.order_history.append(order_result)
                logger.info(f"Successfully opened {position_type} position at {actual_entry_price}")
                return True
            else:
                logger.error(f"Failed to open position: {order_result}")
                return False
        except Exception as e:
            logger.error(f"Error opening position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def close_position(self, reason="manual_close"):
        """Close an existing position"""
        if self.position == 'flat':
            logger.info("No position to close")
            return True
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key:
                logger.warning("No API keys provided. Cannot close position.")
                return False
            # First, cancel any existing stop loss/take profit orders
            try:
                self.client.account.cancel_open_orders(self.formatted_symbol)
                logger.info("Canceled all open orders")
            except Exception as e:
                logger.warning(f"Error canceling open orders: {e}")
            # Determine order side (opposite of position)
            side = Side.SELL if self.position == 'long' else Side.BUY
            # Calculate quantity
            quantity = self.position_size / self.entry_price
            # Round quantity to appropriate precision
            quantity = round(quantity, 4)  # Adjust precision as needed
            logger.info(f"Closing {self.position} position: {quantity} {self.symbol} at market price")
            # Execute market order to close position
            order_result = self.client.account.new_order(
                self.formatted_symbol,
                side,
                OrderType.MARKET,
                str(quantity)
            )
            logger.info(f"Close order result: {order_result}")
            # Check if order was filled
            if order_result.get('status') in ('FILLED', 'PARTIALLY_FILLED'):
                # Get actual exit price; if the response carries no usable price,
                # fall back to the current ticker price instead of 0, which
                # would otherwise be recorded as a total loss (long) or gain (short)
                exit_price = float(order_result.get('price') or 0)
                if exit_price <= 0:
                    exit_price = await self.get_market_price() or self.entry_price
                # Calculate PnL
                if self.position == 'long':
                    pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
                else:  # short
                    pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size
                # Record trade
                self.trades.append({
                    'type': self.position,
                    'entry': self.entry_price,
                    'exit': exit_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': reason,
                    'order_id': order_result.get('orderId')
                })
                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
                # Track order history
                self.order_history.append(order_result)
                logger.info(f"Closed {self.position} position at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0
                return True
            else:
                logger.error(f"Failed to close position: {order_result}")
                return False
        except Exception as e:
            logger.error(f"Error closing position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def check_order_status(self, order_id):
        """Check the status of a specific order"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key:
                logger.warning("No API keys provided. Cannot check order status.")
                return None
            order_status = self.client.account.query_order(
                self.formatted_symbol,
                order_id=order_id
            )
            return order_status
        except Exception as e:
            logger.error(f"Error checking order status: {e}")
            return None

    async def get_market_price(self):
        """Get current market price for the symbol"""
        try:
            ticker = self.client.market.ticker_price(self.formatted_symbol)
            if isinstance(ticker, list) and len(ticker) > 0 and 'price' in ticker[0]:
                return float(ticker[0]['price'])
            elif isinstance(ticker, dict) and 'price' in ticker:
                return float(ticker['price'])
            else:
                logger.error(f"Unexpected ticker format: {ticker}")
                return None
        except Exception as e:
            logger.error(f"Error getting market price: {e}")
            return None


@ -0,0 +1,456 @@
# mexc-api-sdk
The official MEXC market and trade API SDK. It makes it easy to connect and send requests to the MEXC open API.
## Prerequisites
- To use the SDK, you must install Node.js LTS (see https://aws.github.io/jsii/user-guides/lib-user/)
## Installation
1.
```
git clone https://github.com/mxcdevelop/mexc-api-sdk.git
```
2. `cd dist/{language}` and unzip the file
3. Five languages are offered: dotnet, go, java, js, python
## Table of APIS
- [Init](#init)
- [Market](#market)
- [Ping](#ping)
- [Check Server Time](#check-server-time)
- [Exchange Information](#exchange-information)
- [Recent Trades List](#recent-trades-list)
- [Order Book](#order-book)
- [Old Trade Lookup](#old-trade-lookup)
- [Aggregate Trades List](#aggregate-trades-list)
- [kline Data](#kline-data)
- [Current Average Price](#current-average-price)
- [24hr Ticker Price Change Statistics](#24hr-ticker-price-change-statistics)
- [Symbol Price Ticker](#symbol-price-ticker)
- [Symbol Order Book Ticker](#symbol-order-book-ticker)
- [Trade](#trade)
- [Test New Order](#test-new-order)
- [New Order](#new-order)
- [cancel-order](#cancel-order)
- [Cancel all Open Orders on a Symbol](#cancel-all-open-orders-on-a-symbol)
- [Query Order](#query-order)
- [Current Open Orders](#current-open-orders)
- [All Orders](#all-orders)
- [Account Information](#account-information)
- [Account Trade List](#account-trade-list)
## Init
```javascript
//Javascript
import * as Mexc from 'mexc-sdk';
const apiKey = 'apiKey'
const apiSecret = 'apiSecret'
const client = new Mexc.Spot(apiKey, apiSecret);
```
```go
// Go
package main
import (
"fmt"
"mexc-sdk/mexcsdk"
)
func main() {
apiKey := "apiKey"
apiSecret := "apiSecret"
spot := mexcsdk.NewSpot(apiKey, apiSecret)
}
```
```python
# python
from mexc_sdk import Spot
spot = Spot(api_key='apiKey', api_secret='apiSecret')
```
```java
// java
import Mexc.Sdk.*;
class MyClass {
public static void main(String[] args) {
String apiKey= "apiKey";
String apiSecret= "apiSecret";
Spot mySpot = new Spot(apiKey, apiSecret);
}
}
```
```C#
// dotnet
using System;
using System.Collections.Generic;
using Mxc.Sdk;
namespace dotnet
{
class Program
{
static void Main(string[] args)
{
string apiKey = "apiKey";
string apiSecret= "apiSecret";
var spot = new Spot(apiKey, apiSecret);
}
}
}
```
## Market
### Ping
```javascript
client.ping()
```
### Check Server Time
```javascript
client.time()
```
### Exchange Information
```javascript
client.exchangeInfo(options: any)
options:{symbol, symbols}
/**
* choose one parameter
*
* symbol :
* example "BNBBTC";
*
* symbols :
* array of symbol
* example ["BTCUSDT","BNBBTC"];
*
*/
```
### Recent Trades List
```javascript
client.trades(symbol: string, options: any = { limit: 500 })
options:{limit}
/**
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### Order Book
```javascript
client.depth(symbol: string, options: any = { limit: 100 })
options:{limit}
/**
* limit :
* Number of returned data
* Default 100;
* max 5000;
* Valid:[5, 10, 20, 50, 100, 500, 1000, 5000]
*
*/
```
### Old Trade Lookup
```javascript
client.historicalTrades(symbol: string, options: any = { limit: 500 })
options:{limit, fromId}
/**
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
* fromId:
* Trade id to fetch from. Default gets most recent trades
*
*/
```
### Aggregate Trades List
```javascript
client.aggTrades(symbol: string, options: any = { limit: 500 })
options:{fromId, startTime, endTime, limit}
/**
*
* fromId :
* id to get aggregate trades from INCLUSIVE
*
* startTime:
* start at
*
* endTime:
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### kline Data
```javascript
client.klines(symbol: string, interval: string, options: any = { limit: 500 })
options:{ startTime, endTime, limit}
/**
*
* interval :
* m :minute;
* h :Hour;
* d :day;
* w :week;
* M :month
* example : "1m"
*
* startTime :
* start at
*
* endTime :
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
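The interval letters listed above can be converted to a duration on the client side, for example to decide how often to poll for new candles. The following is a minimal Python sketch, not part of the SDK: `interval_to_seconds` is a hypothetical helper, and treating a month (`M`) as 30 days is an assumption of this sketch, since calendar months vary in length.

```python
import re

# Seconds per unit for the interval letters documented above.
# "M" (month) is approximated as 30 days; this is an assumption of the sketch.
_UNIT_SECONDS = {
    "m": 60,
    "h": 60 * 60,
    "d": 24 * 60 * 60,
    "w": 7 * 24 * 60 * 60,
    "M": 30 * 24 * 60 * 60,
}


def interval_to_seconds(interval: str) -> int:
    """Convert an interval string such as "1m" or "4h" to seconds."""
    match = re.fullmatch(r"(\d+)([mhdwM])", interval)
    if match is None:
        raise ValueError(f"unsupported interval: {interval!r}")
    count = int(match.group(1))
    if count == 0:
        raise ValueError("interval count must be positive")
    return count * _UNIT_SECONDS[match.group(2)]
```

The unit letter is case-sensitive, so `"1m"` (minute) and `"1M"` (month) are kept distinct, as in the table above.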
### Current Average Price
```javascript
client.avgPrice(symbol: string)
```
### 24hr Ticker Price Change Statistics
```javascript
client.ticker24hr(symbol?: string)
```
### Symbol Price Ticker
```javascript
client.tickerPrice(symbol?: string)
```
### Symbol Order Book Ticker
```javascript
client.bookTicker(symbol?: string)
```
## Trade
### Test New Order
```javascript
client.newOrderTest(symbol: string, side: string, orderType: string, options: any = {})
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
/**
*
* side:
* Order side
* ENUM:
* BUY
* SELL
*
* orderType:
* Order type
* ENUM:
* LIMIT
* MARKET
* STOP_LOSS
* STOP_LOSS_LIMIT
* TAKE_PROFIT
* TAKE_PROFIT_LIMIT
* LIMIT_MAKER
*
* timeInForce :
* How long an order will be active before expiration.
* GTC: Active unless the order is canceled
* IOC: Order will try to fill the order as much as it can before the order expires
* FOK: Active unless the full order cannot be filled upon execution.
*
* quantity :
* target quantity
*
* quoteOrderQty :
* Specify the total spent or received
*
* price :
* target price
*
* newClientOrderId :
* A unique id among open orders. Automatically generated if not sent
*
* stopPrice :
* Used with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
*
* icebergQty :
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
*
* newOrderRespType :
* Set the response JSON. ACK, RESULT, or FULL;
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
*
* recvWindow :
* Delay accept time
* The value cannot be greater than 60000
* defaults: 5000
*
*/
```
### New Order
```javascript
client.newOrder(symbol: string, side: string, orderType: string, options: any = {})
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
/**
*
* side:
* Order side
* ENUM:
* BUY
* SELL
*
* orderType:
* Order type
* ENUM:
* LIMIT
* MARKET
* STOP_LOSS
* STOP_LOSS_LIMIT
* TAKE_PROFIT
* TAKE_PROFIT_LIMIT
* LIMIT_MAKER
*
* timeInForce :
* How long an order will be active before expiration.
* GTC: Active unless the order is canceled
* IOC: Order will try to fill the order as much as it can before the order expires
* FOK: Active unless the full order cannot be filled upon execution.
*
* quantity :
* target quantity
*
* quoteOrderQty :
* Specify the total spent or received
*
* price :
* target price
*
* newClientOrderId :
* A unique id among open orders. Automatically generated if not sent
*
* stopPrice :
* Used with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
*
* icebergQty :
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
*
* newOrderRespType :
* Set the response JSON. ACK, RESULT, or FULL;
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
*
* recvWindow :
* Delay accept time
* The value cannot be greater than 60000
* defaults: 5000
*
*/
```
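The enums and the `recvWindow` bound documented above can be checked before an order is submitted. The following Python sketch shows one way to do that. `build_order_options` is a hypothetical helper, not an SDK function, and the requirement that `recvWindow` be greater than zero is an assumption; only the upper bound of 60000 and the default of 5000 come from the text above.

```python
ORDER_SIDES = {"BUY", "SELL"}
ORDER_TYPES = {
    "LIMIT",
    "MARKET",
    "STOP_LOSS",
    "STOP_LOSS_LIMIT",
    "TAKE_PROFIT",
    "TAKE_PROFIT_LIMIT",
    "LIMIT_MAKER",
}
MAX_RECV_WINDOW = 60000  # documented upper bound


def build_order_options(side, order_type, recv_window=5000, **options):
    """Validate side/type/recvWindow and return an options dict without None values."""
    if side not in ORDER_SIDES:
        raise ValueError(f"side must be one of {sorted(ORDER_SIDES)}")
    if order_type not in ORDER_TYPES:
        raise ValueError(f"orderType must be one of {sorted(ORDER_TYPES)}")
    # Lower bound (> 0) is an assumption of this sketch.
    if not 0 < recv_window <= MAX_RECV_WINDOW:
        raise ValueError(f"recvWindow must be in (0, {MAX_RECV_WINDOW}]")
    # Drop options the caller left unset so they are not sent at all.
    cleaned = {key: value for key, value in options.items() if value is not None}
    cleaned["recvWindow"] = recv_window
    return cleaned
```

The returned dictionary uses the same option names as the `options` object above (for example `quantity`, `price`, `timeInForce`), so it could be passed along unchanged by whatever code actually places the order.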
### cancel-order
```javascript
client.cancelOrder(symbol: string, options:any = {})
options:{ orderId, origClientOrderId, newClientOrderId}
/**
*
* Either orderId or origClientOrderId must be sent
*
* orderId:
* target orderId
*
* origClientOrderId:
* target origClientOrderId
*
* newClientOrderId:
* Used to uniquely identify this cancel. Automatically generated by default.
*
*/
```
### Cancel all Open Orders on a Symbol
```javascript
client.cancelOpenOrders(symbol: string)
```
### Query Order
```javascript
client.queryOrder(symbol: string, options:any = {})
options:{ orderId, origClientOrderId}
/**
*
* Either orderId or origClientOrderId must be sent
*
* orderId:
* target orderId
*
* origClientOrderId:
* target origClientOrderId
*
*/
```
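The rule that either `orderId` or `origClientOrderId` must be sent can be enforced before the request is made. A minimal Python sketch, where `order_lookup_options` is a hypothetical helper rather than part of the SDK:

```python
def order_lookup_options(order_id=None, orig_client_order_id=None):
    """Build lookup options; at least one of the two identifiers is required."""
    if order_id is None and orig_client_order_id is None:
        raise ValueError("either order_id or orig_client_order_id must be given")
    options = {}
    if order_id is not None:
        options["orderId"] = order_id
    if orig_client_order_id is not None:
        options["origClientOrderId"] = orig_client_order_id
    return options
```

The same check applies to `cancelOrder`, which documents the identical either/or requirement.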
### Current Open Orders
```javascript
client.openOrders(symbol: string)
```
### All Orders
```javascript
client.allOrders(symbol: string, options: any = { limit: 500 })
options:{ orderId, startTime, endTime, limit}
/**
*
* orderId:
* target orderId
*
* startTime:
* start at
*
* endTime:
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### Account Information
```javascript
client.accountInfo()
```
### Account Trade List
```javascript
client.accountTradeList(symbol: string, options:any = { limit: 500 })
options:{ orderId, startTime, endTime, fromId, limit}
/**
*
* orderId:
* target orderId
*
* startTime:
* start at
*
* endTime:
* end at
*
* fromId:
* TradeId to fetch from. Default gets most recent trades
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```


@ -1,188 +1,91 @@
# Crypto Trading Bot with Reinforcement Learning # Crypto Trading Bot with MEXC API Integration
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition. This is an AI-powered cryptocurrency trading bot that can run in both simulation (demo) mode and live trading mode using the MEXC exchange API.
## Features ## Features
- Deep Q-Learning with experience replay - Deep Reinforcement Learning agent for trading decisions
- LSTM layers for sequential data processing - Technical indicators and price prediction
- Multi-head attention mechanism - Live trading integration with MEXC exchange via mexc-api
- Dueling DQN architecture - Demo mode for testing without real trades
- Real-time trading capabilities - Real-time data streaming via websockets
- TensorBoard integration for monitoring - Performance tracking and visualization
- Comprehensive technical indicators
- Demo and live trading modes
- Automatic model checkpointing
## Prerequisites ## Setup
- Python 3.8+ 1. Clone the repository
- MEXC Exchange API credentials 2. Install dependencies:
- GPU recommended but not required
## Installation
1. Clone the repository:
```bash
git clone https://github.com/yourusername/crypto-trading-bot.git
cd crypto-trading-bot
``` ```
2. Create a virtual environment:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
3. Install dependencies:
```bash
pip install -r requirements.txt pip install -r requirements.txt
``` ```
3. Create a `.env` file in the root directory with your MEXC API keys:
4. Create a `.env` file in the project root with your MEXC API credentials:
```bash
MEXC_API_KEY=your_api_key
MEXC_API_SECRET=your_api_secret
cuda support
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
``` ```
MEXC_API_KEY=your_api_key_here
MEXC_SECRET_KEY=your_secret_key_here
```
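A missing or empty key is easier to diagnose at startup than at the first private API call. The sketch below checks both variables using only the standard library. `load_mexc_credentials` is a hypothetical helper, and it assumes the variables are already present in the process environment (for example after `load_dotenv()` has been called, as `mexc_trading.py` does).

```python
import os


def load_mexc_credentials(env=None):
    """Return (api_key, secret_key), or raise if either variable is missing or empty."""
    env = os.environ if env is None else env
    api_key = env.get("MEXC_API_KEY", "").strip()
    secret_key = env.get("MEXC_SECRET_KEY", "").strip()
    missing = [
        name
        for name, value in (("MEXC_API_KEY", api_key), ("MEXC_SECRET_KEY", secret_key))
        if not value
    ]
    if missing:
        raise RuntimeError("missing environment variables: " + ", ".join(missing))
    return api_key, secret_key
```

Passing a plain dictionary as `env` makes the check easy to exercise in tests without touching the real environment.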
## Usage ## Usage
The bot can be run in three modes: The bot can be run in three different modes:
### Training Mode ### Training Mode
```bash Train the agent on historical data:
python main.py --mode train --episodes 1000
```
python main.py --mode train --episodes 100
``` ```
### Evaluation Mode ### Evaluation Mode
```bash Evaluate the trained agent on historical data:
python main.py --mode eval --episodes 10
```
python main.py --mode evaluate
``` ```
### Live Trading Mode ### Live Trading Mode
```bash Run the bot in live trading mode:
# Demo mode (simulated trading with real market data)
python main.py --mode live --demo
# Real trading (actual trades on MEXC) ```
python main.py --mode live python main.py --mode live
``` ```
Demo mode simulates trading using real-time market data but does not execute actual trades. It still: To run in demo mode (no real trades):
- Logs all trading decisions and performance metrics
- Updates the model based on market data (if in training mode)
- Displays real-time analytics and position information
- Calculates theoretical profits/losses
- Saves performance data to TensorBoard
This makes it perfect for testing strategies without financial risk. ```
python main.py --mode live --demo
```
## Live Trading Implementation
The bot uses the mexc-api package to execute trades on the MEXC exchange. The implementation includes:
- Market order execution for opening and closing positions
- Stop loss and take profit orders
- Real-time balance updates
- Trade history tracking
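The stop loss, take profit, and PnL arithmetic behind these features can be sketched in a few lines of Python. The PnL formula mirrors the one in `close_position` in `mexc_trading.py`; `bracket_prices` and `position_pnl` are hypothetical helper names, and deriving the stop and target prices as a percentage offset from the entry price is an assumption about how `STOP_LOSS_PERCENT` and `TAKE_PROFIT_PERCENT` are applied.

```python
def bracket_prices(position_type, entry_price, stop_loss_percent, take_profit_percent):
    """Return (stop_loss_price, take_profit_price) for a long or short position."""
    if position_type not in ("long", "short"):
        raise ValueError("position_type must be 'long' or 'short'")
    sl = stop_loss_percent / 100
    tp = take_profit_percent / 100
    if position_type == "long":
        # Long: stop below entry, target above entry.
        return entry_price * (1 - sl), entry_price * (1 + tp)
    # Short: stop above entry, target below entry.
    return entry_price * (1 + sl), entry_price * (1 - tp)


def position_pnl(position_type, entry_price, exit_price, position_size):
    """Return (pnl_percent, pnl_dollar); position_size is the notional in quote currency."""
    if position_type not in ("long", "short"):
        raise ValueError("position_type must be 'long' or 'short'")
    direction = 1 if position_type == "long" else -1
    pnl_percent = direction * (exit_price - entry_price) / entry_price * 100
    return pnl_percent, pnl_percent / 100 * position_size
```

Fees and leverage are deliberately left out, so the values are gross figures for an unleveraged position.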
## Configuration ## Configuration
Key parameters can be adjusted in `main.py`: You can adjust the following parameters in `main.py`:
- `INITIAL_BALANCE`: Starting balance for training/demo - `INITIAL_BALANCE`: Starting balance for simulation
- `MAX_LEVERAGE`: Maximum leverage for trades - `MAX_LEVERAGE`: Leverage to use for trading
- `STOP_LOSS_PERCENT`: Stop loss percentage - `STOP_LOSS_PERCENT`: Default stop loss percentage
- `TAKE_PROFIT_PERCENT`: Take profit percentage - `TAKE_PROFIT_PERCENT`: Default take profit percentage
- `BATCH_SIZE`: Training batch size
- `LEARNING_RATE`: Model learning rate
- `STATE_SIZE`: Size of the state representation
## Model Architecture ## Architecture
The DQN model includes:
- Input layer with technical indicators
- LSTM layers for temporal pattern recognition
- Multi-head attention mechanism
- Dueling architecture for better Q-value estimation
- Batch normalization for stable training
## Monitoring
Training progress can be monitored using TensorBoard:
Training progress is logged to TensorBoard:
```bash
tensorboard --logdir=logs
```
This will show:
- Training rewards
- Account balance
- Win rate
- Loss metrics
## Trading Strategy
The bot makes decisions based on:
- Price action
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Historical patterns through LSTM
- Risk management with stop-loss and take-profit
## Safety Features
- Demo mode for safe testing
- Automatic stop-loss
- Position size limits
- Error handling for API calls
- Logging of all actions
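
One common way to implement error handling for API calls is a retry wrapper with exponential backoff. The sketch below is a generic pattern under that assumption; `call_with_retry` and its parameters are hypothetical names and do not describe the bot's actual code.

```python
import logging
import time

logger = logging.getLogger(__name__)


def call_with_retry(func, retries=3, base_delay=0.5, sleep=time.sleep):
    """Call `func()` and retry on any exception with exponential backoff.

    The delay doubles after each failed attempt; the last failure is re-raised.
    `sleep` is injectable so the wait can be stubbed out in tests.
    """
    for attempt in range(1, retries + 1):
        try:
            return func()
        except Exception as exc:
            logger.warning("API call failed (attempt %d/%d): %s",
                           attempt, retries, exc)
            if attempt == retries:
                raise
            sleep(base_delay * 2 ** (attempt - 1))
```

A call such as `call_with_retry(lambda: client.market.server_time())` would then tolerate brief network failures while still surfacing persistent errors.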
## Directory Structure
├── main.py # Main bot implementation
├── requirements.txt # Project dependencies
├── .env # API credentials
├── models/ # Saved model checkpoints
├── runs/ # TensorBoard logs
└── trading_bot.log # Activity logs
- `main.py`: Main entry point and trading logic
- `mexc_trading.py`: MEXC API integration for live trading using mexc-api
- `models/`: Directory for saved model weights
## Warning
Cryptocurrency trading carries significant risk. This bot is provided for educational purposes only and should not be used with real money without thorough testing and a clear understanding of the risks involved. Use at your own risk.
## License
[MIT License](LICENSE)
The main changes I made:
- Fixed code block formatting by adding proper language identifiers
- Added missing closing code blocks
- Properly formatted directory structure
- Added complete sections that were cut off in the original
- Ensured consistent formatting throughout the document
- Added proper bash syntax highlighting for command examples
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
# Edits/improvements
- Fixes the shape mismatch by ensuring the state vector is exactly `STATE_SIZE` elements
- Adds robust error handling in the model's forward pass to handle mismatched inputs
- Adds a transformer encoder for more sophisticated pattern recognition
- Provides an `expand_model` method to increase model capacity while preserving learned weights
- Adds detailed logging about model size and shape mismatches
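
A simple way to guarantee that a state vector has exactly `STATE_SIZE` elements is to truncate it when it is too long and pad it when it is too short. The sketch below illustrates that idea; the `fit_state` helper and the zero padding value are assumptions, not the code from this commit.

```python
def fit_state(features, state_size, pad_value=0.0):
    """Return a list of exactly `state_size` elements.

    Extra elements are dropped from the end; missing ones are filled
    with `pad_value`.
    """
    features = list(features)
    if len(features) >= state_size:
        return features[:state_size]
    return features + [pad_value] * (state_size - len(features))


if __name__ == "__main__":
    print(fit_state([0.1, 0.2, 0.3], 5))
    print(fit_state([0.1, 0.2, 0.3], 2))
```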
The model now has:
- Configurable hidden layer sizes
- Transformer layers for complex pattern recognition
- LSTM layers for temporal patterns
- Attention mechanisms for focusing on important features
- Dueling architecture for better Q-value estimation
With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need to implement a more distributed architecture with multiple GPUs, which would require significant changes to the training loop.
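
To get a feel for how parameter counts grow with `hidden_size`, the sketch below counts the parameters of the LSTM portion alone. It follows the layout of a unidirectional PyTorch `nn.LSTM` without projection, which has four gates per layer, each with an input weight, a hidden weight, and two bias vectors. The input size of 64 and the two-layer depth are hypothetical, and the transformer, attention, and dueling heads are not included, so the totals are lower than the full model's.

```python
def lstm_param_count(input_size, hidden_size, num_layers=1):
    """Parameter count of a unidirectional LSTM laid out like torch.nn.LSTM.

    Per layer: 4 * hidden * (layer_input + hidden + 2), where the "+ 2"
    accounts for the two bias vectors (bias_ih and bias_hh).
    """
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        total += 4 * hidden_size * (layer_input + hidden_size + 2)
        # Layers after the first take the previous layer's hidden state as input.
        layer_input = hidden_size
    return total


if __name__ == "__main__":
    for hidden in (256, 512, 1024):
        # input_size=64 and num_layers=2 are illustrative assumptions.
        print(hidden, lstm_param_count(64, hidden, num_layers=2))
```

Because the hidden-to-hidden weights scale with the square of `hidden_size`, doubling it roughly quadruples the LSTM's parameter count.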

@ -1,10 +1,11 @@
-numpy>=1.21.0
+numpy>=1.20.0
 pandas>=1.3.0
 matplotlib>=3.4.0
 torch>=1.9.0
-python-dotenv>=0.19.0
+scikit-learn>=0.24.0
 ccxt>=2.0.0
+python-dotenv>=0.19.0
 websockets>=10.0
-tensorboard>=2.6.0
+tensorboard>=2.7.0
-scikit-learn
-mplfinance
+mexc-api>=1.0.0
+asyncio>=3.4.3

@ -0,0 +1,23 @@
from mexc_api.spot import Spot


def test_mexc_api():
    try:
        # Initialize client with empty API keys for public endpoints
        client = Spot("", "")

        # Test server time endpoint
        server_time = client.market.server_time()
        print(f"Server time: {server_time}")

        # Test ticker price endpoint
        ticker = client.market.ticker_price("ETHUSDT")
        print(f"ETH/USDT price: {ticker}")

        print("MEXC API is working correctly!")
        return True
    except Exception as e:
        print(f"Error testing MEXC API: {e}")
        return False


if __name__ == "__main__":
    test_mexc_api()

@ -0,0 +1,34 @@
import asyncio
import os
from dotenv import load_dotenv
from mexc_trading import MexcTradingClient

# Load environment variables
load_dotenv()


async def test_trading_client():
    """Test the MexcTradingClient functionality"""
    print("Initializing MexcTradingClient...")
    client = MexcTradingClient(symbol="ETH/USDT")

    # Test getting market price
    print("Testing get_market_price...")
    price = await client.get_market_price()
    print(f"Current ETH/USDT price: {price}")

    # If API keys are provided, test account balance
    if os.getenv('MEXC_API_KEY') and os.getenv('MEXC_SECRET_KEY'):
        print("Testing fetch_account_balance...")
        balance = await client.fetch_account_balance()
        print(f"Account balance: {balance} USDT")

        print("Testing fetch_open_positions...")
        positions = await client.fetch_open_positions()
        print(f"Open positions: {positions}")
    else:
        print("No API keys provided. Skipping private endpoint tests.")

    print("MexcTradingClient test completed!")


if __name__ == "__main__":
    asyncio.run(test_trading_client())
