Compare commits

...

16 Commits

Author SHA1 Message Date
Dobromir Popov
d5c291d15c more features, but dead-end 2025-03-17 00:29:50 +02:00
Dobromir Popov
2e901e18f2 improvments and fixes 2025-03-12 16:52:49 +02:00
Dobromir Popov
506458d55e wip enhanced multitimeframe model 2025-03-12 01:46:48 +02:00
Dobromir Popov
ad559d8c61 fixes, lots of new ideas 2025-03-12 00:56:32 +02:00
Dobromir Popov
d9f1bac11c lots of stufff 2025-03-12 00:08:02 +02:00
Dobromir Popov
4f6db5e86c Merge commit '3e924b32ac1737c859430186ddf6f288a1fba465' into march-trader-sonnet-3.7 2025-03-10 18:27:57 +02:00
Dobromir Popov
3e924b32ac models 2025-03-10 18:27:34 +02:00
Dobromir Popov
8dafb6d310 Merge commit '621a2505bd55db0f8295e5379638d7b1c7523620' 2025-03-10 16:43:19 +02:00
Dobromir Popov
621a2505bd models 2025-03-10 16:42:49 +02:00
Dobromir Popov
08b8da7c8f fix refactoring 2025-03-10 16:38:37 +02:00
Dobromir Popov
e884f0c9e6 added continious mode. fixed errors 2025-03-10 15:37:02 +02:00
Dobromir Popov
cfddc996d7 fixes 2025-03-10 14:53:21 +02:00
Dobromir Popov
6f78703ba1 plot charts 2025-03-10 14:48:29 +02:00
Dobromir Popov
715261a3f9 improvements 2025-03-10 13:32:35 +02:00
Dobromir Popov
2b1f00cbfc working on GPU 2025-03-10 13:15:30 +02:00
Dobromir Popov
643bc154a2 improvements 2025-03-10 13:07:07 +02:00
35 changed files with 217083 additions and 2118 deletions

8
.gitignore vendored
View File

@ -32,6 +32,12 @@ crypto/sol/.vs/*
crypto/brian/models/best/* crypto/brian/models/best/*
crypto/brian/models/last/* crypto/brian/models/last/*
crypto/brian/live_chart.html crypto/brian/live_chart.html
crypto/gogo2/models/*
crypto/gogo2/trading_bot.log crypto/gogo2/trading_bot.log
*.log *.log
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
*trading_agent_continuous_*.pt
*trading_agent_episode_*.pt
crypto/gogo2/models/trading_agent_continuous_*.pt
crypto/gogo2/visualizations/training_episode_*.png
crypto/gogo2/checkpoints/trading_agent_episode_*.pt

View File

@ -5,8 +5,8 @@
"name": "Train Bot", "name": "Train Bot",
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "main.py", "program": "main_multiu_broken.py",
"args": ["--mode", "train", "--episodes", "1000"], "args": ["--mode", "train", "--episodes", "10000"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
}, },
@ -14,7 +14,7 @@
"name": "Evaluate Bot", "name": "Evaluate Bot",
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "main.py", "program": "main_multiu_broken.py",
"args": ["--mode", "eval", "--episodes", "10"], "args": ["--mode", "eval", "--episodes", "10"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
@ -23,7 +23,7 @@
"name": "Live Trading (Demo)", "name": "Live Trading (Demo)",
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "main.py", "program": "main_multiu_broken.py",
"args": ["--mode", "live", "--demo"], "args": ["--mode", "live", "--demo"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
@ -32,10 +32,19 @@
"name": "Live Trading (Real)", "name": "Live Trading (Real)",
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "main.py", "program": "main_multiu_broken.py",
"args": ["--mode", "live"], "args": ["--mode", "live"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
},
{
"name": "Continuous Training",
"type": "python",
"request": "launch",
"program": "main_multiu_broken.py",
"args": ["--mode", "continuous", "--refresh-data"],
"console": "integratedTerminal",
"justMyCode": true
} }
] ]
} }

View File

@ -0,0 +1,42 @@
def step(self, action):
"""Take an action in the environment and return the next state, reward, and done flag"""
# Store current price before taking action
self.current_price = self.data[self.current_step]['close']
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
if not self.demo and self.trading_client:
# Execute real trades in live mode
asyncio.create_task(self._execute_live_action(action))
# Calculate reward (simulation still runs in parallel with live trading)
reward, _ = self.calculate_reward(action) # Unpack the tuple here
# Check for stop loss / take profit hits
self.check_sl_tp()
# Move to next step
self.current_step += 1
done = self.current_step >= len(self.data) - 1
# Get new state
next_state = self.get_state()
return next_state, reward, done
def calculate_reward(self, action):
"""Calculate the reward for the current action."""
# ... (existing code)
# Combine all reward components
reward = pnl_reward + timing_reward + risk_reward + prediction_reward
# Log components for analysis
info = {
'pnl_reward': pnl_reward,
'timing_reward': timing_reward,
'risk_reward': risk_reward,
'prediction_reward': prediction_reward,
'total_reward': reward
}
return reward # Return only the reward, not the info dictionary

67
crypto/gogo2/_model.md Normal file
View File

@ -0,0 +1,67 @@
# Neural Network Architecture Analysis for Trading Bot
## Overview
This document provides a comprehensive analysis of the neural network architecture used in our trading bot system. The system consists of two main neural network components:
1. **Price Prediction Model** - Forecasts future price movements and extrema points
2. **DQN (Deep Q-Network)** - Makes trading decisions based on state representations
## 1. Price Prediction Model
### Architecture
```
PricePredictionModel(nn.Module)
├── Input Layer: [batch_size, seq_len, 2] (price, volume)
├── LSTM Layers: 2 stacked layers with hidden_size=128
├── Attention Mechanism: Self-attention with linear projections
├── Linear Layer 1: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 2: hidden_size → output_size (5 future prices)
└── Output: [batch_size, output_size]
```
### Data Flow
**Inputs:**
- `price_history`: Sequence of historical prices [batch_size, seq_len]
- `volume_history`: Sequence of historical volumes [batch_size, seq_len]
**Preprocessing:**
- Normalization using MinMaxScaler (0-1 range)
- Reshaping to [batch_size, seq_len, 2] (price and volume features)
**Forward Pass:**
1. Input data passes through LSTM layers
2. Self-attention mechanism applied to LSTM outputs
3. Linear layers process the attended features
4. Output represents predicted prices for next 5 candles
**Outputs:**
- `predicted_prices`: Array of 5 future price predictions
- `predicted_extrema`: Binary indicators for potential price extrema points
## 2. DQN (Deep Q-Network)
### Architecture
```
DQN(nn.Module)
├── Input Layer: [batch_size, state_size]
├── Linear Layer 1: state_size → hidden_size (384)
├── ReLU Activation
├── LSTM Layers: 2 stacked layers with hidden_size=384
├── Multi-Head Attention: 4 attention heads
├── Linear Layer 2: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 3: hidden_size → action_size (4)
└── Output: [batch_size, action_size] (Q-values for each action)
```
### Data Flow
**Inputs:**
- `state`: Current market state representation [batch_size, state_size]
- Price features (normalized prices, returns, volatility)
- Technical indicators (RSI, MACD, Stochastic,

View File

@ -1,9 +1,17 @@
pip install torch-tb-profiler
ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function. ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
do we call the trading api proerly when in live mode? in live mode - also validate the current ballance. also ensure trades are executed after executing the orders by checking the open orders.
our trading data chart (in tensorboard) does not properly displayed - the candles seems displayed multiple times but shifted in time. we also do not correctly show the buy/sell evens on the time axis. we do not show the predicted price on the chart.
2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles 2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance) C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
@ -17,3 +25,41 @@ C:\Users\popov\miniforge3\Lib\site-packages\torch\amp\grad_scaler.py:132: UserWa
2025-03-10 12:11:30,928 - ERROR - Training failed: 'TradingEnvironment' object has no attribute 'initialize_price_predictor' 2025-03-10 12:11:30,928 - ERROR - Training failed: 'TradingEnvironment' object has no attribute 'initialize_price_predictor'
2025-03-10 12:11:30,928 - INFO - Exchange connection closed 2025-03-10 12:11:30,928 - INFO - Exchange connection closed
Backend tkagg is interactive backend. Turning interactive mode on. Backend tkagg is interactive backend. Turning interactive mode on.
2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
----------------
do we train the price prediction model by using old known candles and masking the latest to make him guess the next and then backpropagating the next already known candle ? like a transformer (gpt2) would do? or we use RL for that as well?
it seems the model is not learning a lot. we keep hovering about the same starting balance even after some time in training in continious mode
it seems we may need another NN model down the loop jut to predict the extremums of the price.
we may have to include a mechanism to calculate the extremums of the price retrospectively and to use that to bootstrap pre-train the model.
Why Performance Might Be Stagnating
Several factors could explain why the model isn't improving significantly during training:
Insufficient Model Capacity for Price Prediction: While the price prediction model has 267,663 parameters, financial time series prediction is extremely challenging. The market may have patterns that are too complex or too random for the current model to capture effectively.
Overfitting to Training Data: The model might be memorizing patterns in the training data that don't generalize to new market conditions.
Transformer-LSTM Redundancy in DQN: Your DQN model uses both a transformer and an LSTM, which might be redundant. Both are designed to capture sequential dependencies, and having both could lead to overfitting or training instability.
Imbalanced Parameter Distribution: 64.5% of your DQN parameters are in the transformer component, which might be excessive for the task.
Reward Function Issues: The reward function might not be properly aligned with profitable trading strategies, or it might be too sparse to provide meaningful learning signals.
Suggested Improvements
1. Enhance Price Prediction Training
2. Simplify the DQN Architecture
Consider creating a more streamlined DQN model:
3. Improve the Reward Function
Make sure your reward function provides meaningful signals for learning:
4. Implement Curriculum Learning
Start with simpler market conditions and gradually increase complexity:
Conclusion
The issue appears to be a combination of model complexity, potential overfitting, and possibly insufficient learning signals from the reward function. By simplifying the DQN architecture (particularly reducing the transformer component), improving the price prediction training, and enhancing the reward function, you should see better learning progress.
Would you like me to implement any of these specific improvements to your codebase?

362
crypto/gogo2/archive.py Normal file
View File

@ -0,0 +1,362 @@
class MexcTradingClient:
def __init__(self, api_key, secret_key, symbol="ETH/USDT", leverage=50):
self.client = ccxt.mexc({
'apiKey': api_key,
'secret': secret_key,
'enableRateLimit': True,
})
self.symbol = symbol
self.leverage = leverage
self.position = 'flat'
self.position_size = 0
self.entry_price = 0
self.stop_loss = 0
self.take_profit = 0
self.trades = []
def initialize_mexc_client(self, api_key, api_secret):
"""Initialize the MEXC API client"""
try:
from mexc_sdk import Spot
self.mexc_client = Spot(api_key=api_key, api_secret=api_secret)
# Test connection
self.mexc_client.ping()
logger.info("MEXC API client initialized successfully")
return True
except Exception as e:
logger.error(f"Failed to initialize MEXC API client: {e}")
logger.error(traceback.format_exc())
return False
async def fetch_account_balance(self):
"""Fetch actual account balance from MEXC API"""
if self.demo or not self.mexc_client:
# In demo mode, use simulated balance
return self.balance
try:
account_info = self.mexc_client.accountInfo()
if 'balances' in account_info:
# Find USDT balance
for asset in account_info['balances']:
if asset['asset'] == 'USDT':
return float(asset['free'])
logger.warning("Could not find USDT balance, using current simulated balance")
return self.balance
except Exception as e:
logger.error(f"Error fetching account balance: {e}")
logger.error(traceback.format_exc())
# Fallback to simulated balance in case of API error
return self.balance
async def fetch_open_positions(self):
"""Fetch actual open positions from MEXC API"""
if self.demo or not self.mexc_client:
# In demo mode, return current simulated position
return [{
'symbol': 'ETH/USDT',
'positionSide': 'LONG' if self.position == 'long' else 'SHORT' if self.position == 'short' else 'NONE',
'positionAmt': self.position_size / self.current_price if self.position != 'flat' else 0,
'entryPrice': self.entry_price,
'unrealizedProfit': self.calculate_unrealized_pnl()
}] if self.position != 'flat' else []
try:
# Fetch open positions
positions = self.mexc_client.openOrders('ETH/USDT')
return positions
except Exception as e:
logger.error(f"Error fetching open positions: {e}")
logger.error(traceback.format_exc())
# Fallback to simulated positions in case of API error
return []
def calculate_unrealized_pnl(self):
"""Calculate unrealized PnL for the current position"""
if self.position == 'flat':
return 0.0
position_value = self.position_size / self.entry_price
if self.position == 'long':
pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
else: # short
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
# Apply leverage
pnl_percent *= self.leverage
return position_value * pnl_percent / 100
async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
"""Open a new position using MEXC API in live mode, or simulate in demo mode"""
if self.demo or not self.mexc_client:
# In demo mode, simulate opening a position
self.position = position_type
self.position_size = size
self.entry_price = entry_price
self.entry_index = self.current_step
self.stop_loss = stop_loss
self.take_profit = take_profit
logger.info(f"DEMO: Opened {position_type.upper()} position at {entry_price} | " +
f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
return True
try:
# In live mode, place actual orders via API
symbol = "ETHUSDT" # Format required by MEXC
side = "BUY" if position_type == 'long' else "SELL"
# Calculate quantity based on size and price
quantity = size / entry_price
# Place main order
order_result = self.mexc_client.newOrder(
symbol=symbol,
side=side,
orderType="MARKET",
quantity=quantity,
options={
"leverage": self.leverage,
"newOrderRespType": "FULL"
}
)
# Check if order executed
if order_result.get('status') == 'FILLED':
actual_entry_price = float(order_result.get('price', entry_price))
# Place stop loss order
sl_order = self.mexc_client.newOrder(
symbol=symbol,
side="SELL" if position_type == 'long' else "BUY",
orderType="STOP_LOSS",
quantity=quantity,
options={
"stopPrice": stop_loss,
"newOrderRespType": "ACK"
}
)
# Place take profit order
tp_order = self.mexc_client.newOrder(
symbol=symbol,
side="SELL" if position_type == 'long' else "BUY",
orderType="TAKE_PROFIT",
quantity=quantity,
options={
"stopPrice": take_profit,
"newOrderRespType": "ACK"
}
)
# Update local state
self.position = position_type
self.position_size = size
self.entry_price = actual_entry_price
self.entry_index = self.current_step
self.stop_loss = stop_loss
self.take_profit = take_profit
# Track orders
self.open_orders.extend([sl_order, tp_order])
self.order_history.append(order_result)
logger.info(f"LIVE: Opened {position_type.upper()} position at {actual_entry_price} | " +
f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
return True
else:
logger.error(f"Failed to execute order: {order_result}")
return False
except Exception as e:
logger.error(f"Error opening position: {e}")
logger.error(traceback.format_exc())
return False
async def close_position(self, reason="manual_close"):
"""Close the current position using MEXC API in live mode, or simulate in demo mode"""
if self.position == 'flat':
logger.info("No position to close")
return False
if self.demo or not self.mexc_client:
# In demo mode, simulate closing a position
position_type = self.position
entry_price = self.entry_price
exit_price = self.current_price
position_size = self.position_size
# Calculate PnL
if position_type == 'long':
pnl_percent = (exit_price - entry_price) / entry_price * 100
else: # short
pnl_percent = (entry_price - exit_price) / entry_price * 100
# Apply leverage
pnl_percent *= self.leverage
# Calculate actual PnL
pnl_dollar = position_size * pnl_percent / 100
# Apply fees
pnl_dollar -= self.calculate_fees(position_size)
# Update balance
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': position_type,
'entry': entry_price,
'exit': exit_price,
'entry_time': self.data[self.entry_index]['timestamp'],
'exit_time': self.data[self.current_step]['timestamp'],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': reason,
'leverage': self.leverage
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
logger.info(f"DEMO: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
return True
try:
# In live mode, close position via API
symbol = "ETHUSDT"
position_info = await self.fetch_open_positions()
if not position_info:
logger.warning("No open positions found to close")
self.position = 'flat'
return False
# First, cancel any existing stop loss/take profit orders
try:
self.mexc_client.cancelOpenOrders(symbol)
except Exception as e:
logger.warning(f"Error canceling open orders: {e}")
# Close the position with a market order
position_type = self.position
side = "SELL" if position_type == 'long' else "BUY"
quantity = self.position_size / self.current_price
# Execute order
order_result = self.mexc_client.newOrder(
symbol=symbol,
side=side,
orderType="MARKET",
quantity=quantity,
options={
"newOrderRespType": "FULL"
}
)
# Check if order executed
if order_result.get('status') == 'FILLED':
exit_price = float(order_result.get('price', self.current_price))
entry_price = self.entry_price
position_size = self.position_size
# Calculate PnL
if position_type == 'long':
pnl_percent = (exit_price - entry_price) / entry_price * 100
else: # short
pnl_percent = (entry_price - exit_price) / entry_price * 100
# Apply leverage
pnl_percent *= self.leverage
# Calculate actual PnL
pnl_dollar = position_size * pnl_percent / 100
# Apply fees
pnl_dollar -= self.calculate_fees(position_size)
# Update balance from API
self.balance = await self.fetch_account_balance()
self.total_pnl += pnl_dollar
self.episode_pnl += pnl_dollar
# Update max drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
drawdown = (self.peak_balance - self.balance) / self.peak_balance
self.max_drawdown = max(self.max_drawdown, drawdown)
# Record trade
self.trades.append({
'type': position_type,
'entry': entry_price,
'exit': exit_price,
'entry_time': self.data[self.entry_index]['timestamp'],
'exit_time': self.data[self.current_step]['timestamp'],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': self.current_step - self.entry_index,
'market_direction': self.get_market_direction(),
'reason': reason,
'leverage': self.leverage,
'order_id': order_result.get('orderId')
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
# Track order history
self.order_history.append(order_result)
logger.info(f"LIVE: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.entry_index = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
return True
else:
logger.error(f"Failed to close position: {order_result}")
return False
except Exception as e:
logger.error(f"Error closing position: {e}")
logger.error(traceback.format_exc())
return False

View File

@ -0,0 +1 @@
{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 0, "timestamp": "2025-03-12T00:23:19.125190"}

View File

@ -0,0 +1,378 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class EnhancedPricePredictionModel(nn.Module):
def __init__(self, input_dim=2, hidden_dim=256, num_layers=3, output_dim=5, num_timeframes=3):
super(EnhancedPricePredictionModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_timeframes = num_timeframes
# Separate LSTM for each timeframe
self.timeframe_lstms = nn.ModuleList([
nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
for _ in range(num_timeframes)
])
# Cross-timeframe attention
self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
# Self-attention for each timeframe
self.self_attentions = nn.ModuleList([
nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
for _ in range(num_timeframes)
])
# Timeframe fusion layer
self.fusion_layer = nn.Sequential(
nn.Linear(hidden_dim * num_timeframes, hidden_dim * 2),
nn.LeakyReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim * 2, hidden_dim)
)
# Fully connected layer for price prediction
self.price_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim)
)
# Fully connected layer for extrema prediction (high and low points)
self.extrema_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
)
# Volume prediction layer
self.volume_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x_list):
# x_list is a list of tensors, one for each timeframe
# Each x shape: (batch_size, seq_len, input_dim)
# Process each timeframe with its own LSTM
lstm_outputs = []
for i, x in enumerate(x_list):
lstm_out, _ = self.timeframe_lstms[i](x) # lstm_out: (batch_size, seq_len, hidden_dim)
lstm_outputs.append(lstm_out)
# Apply self-attention to each timeframe
attn_outputs = []
for i, lstm_out in enumerate(lstm_outputs):
attn_output, _ = self.self_attentions[i](lstm_out, lstm_out, lstm_out)
attn_outputs.append(attn_output[:, -1, :]) # Use the last time step
# Concatenate all timeframe representations
combined = torch.cat(attn_outputs, dim=1) # (batch_size, hidden_dim * num_timeframes)
# Fuse timeframe information
fused = self.fusion_layer(combined) # (batch_size, hidden_dim)
# Price prediction
price_pred = self.price_fc(fused)
# Extrema prediction
extrema_logits = self.extrema_fc(fused)
# Volume prediction
volume_pred = self.volume_fc(fused)
return price_pred, extrema_logits, volume_pred
class EnhancedDQN(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=512):
super(EnhancedDQN, self).__init__()
# Feature extraction layers with increased capacity
self.feature_extraction = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
)
# Advantage stream with increased capacity
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, action_dim)
)
# Value stream with increased capacity
self.value_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, 1)
)
# Enhanced transformer for temporal dependencies
encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=hidden_dim*4,
dropout=0.1,
batch_first=True
)
self.transformer = TransformerEncoder(encoder_layers, num_layers=3)
# LSTM for sequential decision making with increased capacity
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.1)
# Final layers with increased capacity
self.final_layers = nn.Sequential(
nn.Linear(hidden_dim*2, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, action_dim)
)
# Market regime classification layer
self.market_regime_classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, 3) # 3 regimes: trending, ranging, volatile
)
def forward(self, state, hidden=None):
# Extract features
features = self.feature_extraction(state)
features = features.unsqueeze(1) # Add sequence dimension for transformer/LSTM
# Transformer processing
transformer_out = self.transformer(features)
# LSTM processing
lstm_out, lstm_hidden = self.lstm(transformer_out)
# Dueling architecture
advantage = self.advantage_stream(features.squeeze(1))
value = self.value_stream(features.squeeze(1))
# Combine transformer, LSTM and dueling outputs
combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
q_values = self.final_layers(combined)
# Market regime classification
market_regime = self.market_regime_classifier(transformer_out.squeeze(1))
# Dueling Q-value computation
dueling_q = value + advantage - advantage.mean(dim=1, keepdim=True)
# Final Q-values are a weighted combination of the dueling Q-values and the direct Q-values
# This allows the model to use either approach depending on the situation
q_values = 0.5 * dueling_q + 0.5 * q_values
return q_values, lstm_hidden, market_regime
def count_parameters(model):
total_params = 0
layer_params = {}
for name, param in model.named_parameters():
if param.requires_grad:
param_count = param.numel()
total_params += param_count
layer_params[name] = (param_count, param.shape)
return total_params, layer_params
def main():
# Initialize the original models for comparison
original_price_model = PricePredictionModel()
original_price_total_params, _ = count_parameters(original_price_model)
state_dim = 50
action_dim = 3
original_dqn_model = DQN(state_dim=state_dim, action_dim=action_dim)
original_dqn_total_params, _ = count_parameters(original_dqn_model)
# Initialize the enhanced models
enhanced_price_model = EnhancedPricePredictionModel(num_timeframes=3)
enhanced_price_total_params, enhanced_price_layer_params = count_parameters(enhanced_price_model)
# Increase state dimension to accommodate multiple timeframes
enhanced_state_dim = 100 # Increased from 50 to accommodate more features
enhanced_dqn_model = EnhancedDQN(state_dim=enhanced_state_dim, action_dim=action_dim)
enhanced_dqn_total_params, enhanced_dqn_layer_params = count_parameters(enhanced_dqn_model)
# Print comparison
print("=== MODEL SIZE COMPARISON ===")
print(f"Original Price Prediction Model: {original_price_total_params:,} parameters")
print(f"Enhanced Price Prediction Model: {enhanced_price_total_params:,} parameters")
print(f"Growth Factor: {enhanced_price_total_params / original_price_total_params:.2f}x\n")
print(f"Original DQN Model: {original_dqn_total_params:,} parameters")
print(f"Enhanced DQN Model: {enhanced_dqn_total_params:,} parameters")
print(f"Growth Factor: {enhanced_dqn_total_params / original_dqn_total_params:.2f}x\n")
print(f"Total Original Models: {original_price_total_params + original_dqn_total_params:,} parameters")
print(f"Total Enhanced Models: {enhanced_price_total_params + enhanced_dqn_total_params:,} parameters")
print(f"Overall Growth Factor: {(enhanced_price_total_params + enhanced_dqn_total_params) / (original_price_total_params + original_dqn_total_params):.2f}x\n")
# Print VRAM usage estimate (rough approximation)
bytes_per_param = 4 # 4 bytes for float32
original_vram_mb = (original_price_total_params + original_dqn_total_params) * bytes_per_param / (1024 * 1024)
enhanced_vram_mb = (enhanced_price_total_params + enhanced_dqn_total_params) * bytes_per_param / (1024 * 1024)
print("=== ESTIMATED VRAM USAGE ===")
print(f"Original Models: {original_vram_mb:.2f} MB")
print(f"Enhanced Models: {enhanced_vram_mb:.2f} MB")
print(f"Available VRAM: 8,192 MB (8 GB)")
print(f"VRAM Utilization: {enhanced_vram_mb / 8192 * 100:.2f}%\n")
# Print detailed breakdown of enhanced models
print("=== ENHANCED PRICE PREDICTION MODEL BREAKDOWN ===")
# Group parameters by component
timeframe_lstm_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "timeframe_lstms" in name)
attention_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "attention" in name)
fusion_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "fusion" in name)
output_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if any(x in name for x in ["price_fc", "extrema_fc", "volume_fc"]))
print(f"Timeframe LSTMs: {timeframe_lstm_params:,} parameters ({timeframe_lstm_params/enhanced_price_total_params*100:.1f}%)")
print(f"Attention Mechanisms: {attention_params:,} parameters ({attention_params/enhanced_price_total_params*100:.1f}%)")
print(f"Fusion Layer: {fusion_params:,} parameters ({fusion_params/enhanced_price_total_params*100:.1f}%)")
print(f"Output Layers: {output_params:,} parameters ({output_params/enhanced_price_total_params*100:.1f}%)\n")
print("=== ENHANCED DQN MODEL BREAKDOWN ===")
# Group parameters by component
feature_extraction_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "feature_extraction" in name)
advantage_value_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "advantage_stream" in name or "value_stream" in name)
transformer_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "transformer" in name)
lstm_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "lstm" in name and "transformer" not in name)
final_layers_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "final_layers" in name)
market_regime_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "market_regime" in name)
print(f"Feature Extraction: {feature_extraction_params:,} parameters ({feature_extraction_params/enhanced_dqn_total_params*100:.1f}%)")
print(f"Advantage & Value Streams: {advantage_value_params:,} parameters ({advantage_value_params/enhanced_dqn_total_params*100:.1f}%)")
print(f"Transformer: {transformer_params:,} parameters ({transformer_params/enhanced_dqn_total_params*100:.1f}%)")
print(f"LSTM: {lstm_params:,} parameters ({lstm_params/enhanced_dqn_total_params*100:.1f}%)")
print(f"Final Layers: {final_layers_params:,} parameters ({final_layers_params/enhanced_dqn_total_params*100:.1f}%)")
print(f"Market Regime Classifier: {market_regime_params:,} parameters ({market_regime_params/enhanced_dqn_total_params*100:.1f}%)")
# Keep the original models for comparison
class PricePredictionModel(nn.Module):
def __init__(self, input_dim=2, hidden_dim=128, num_layers=2, output_dim=5):
super(PricePredictionModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
# LSTM layers
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
# Self-attention mechanism
self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
# Fully connected layer for price prediction
self.price_fc = nn.Linear(hidden_dim, output_dim)
# Fully connected layer for extrema prediction (high and low points)
self.extrema_fc = nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
def forward(self, x):
# x shape: (batch_size, seq_len, input_dim)
# LSTM forward pass
lstm_out, _ = self.lstm(x) # lstm_out: (batch_size, seq_len, hidden_dim)
# Self-attention
attn_output, _ = self.attention(lstm_out, lstm_out, lstm_out)
# Price prediction
price_pred = self.price_fc(attn_output[:, -1, :]) # Use the last time step
# Extrema prediction
extrema_logits = self.extrema_fc(attn_output[:, -1, :])
return price_pred, extrema_logits
class DQN(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(DQN, self).__init__()
# Feature extraction layers
self.feature_extraction = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
)
# Advantage stream
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, action_dim)
)
# Value stream
self.value_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, 1)
)
# Transformer for temporal dependencies
encoder_layers = TransformerEncoderLayer(d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim*4, batch_first=True)
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
# LSTM for sequential decision making
self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
# Final layers
self.final_layers = nn.Sequential(
nn.Linear(hidden_dim*2, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, state, hidden=None):
# Extract features
features = self.feature_extraction(state)
features = features.unsqueeze(1) # Add sequence dimension for transformer/LSTM
# Transformer processing
transformer_out = self.transformer(features)
# LSTM processing
lstm_out, lstm_hidden = self.lstm(transformer_out)
# Dueling architecture
advantage = self.advantage_stream(features.squeeze(1))
value = self.value_stream(features.squeeze(1))
# Combine transformer, LSTM and dueling outputs
combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
q_values = self.final_layers(combined)
# Dueling Q-value computation
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values, lstm_hidden
if __name__ == "__main__":
main()

4
crypto/gogo2/cuda.py Normal file
View File

@ -0,0 +1,4 @@
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'Not available'}")

319
crypto/gogo2/data_cache.py Normal file
View File

@ -0,0 +1,319 @@
import os
import json
import time
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import logging
# Set up logging
logger = logging.getLogger('trading_bot')
class OHLCVCache:
"""
A simple cache for OHLCV data from exchanges.
Stores data in a structured format and provides backup when exchange is unavailable.
"""
def __init__(self, cache_dir="cache", max_age_hours=24):
"""
Initialize the OHLCV cache.
Args:
cache_dir: Directory to store cache files
max_age_hours: Maximum age of cached data in hours before considered stale
"""
self.cache_dir = cache_dir
self.max_age_seconds = max_age_hours * 3600
# Create cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
# In-memory cache for faster access
self.memory_cache = {}
def _get_cache_filename(self, symbol, timeframe):
"""Generate a standardized filename for the cache file"""
# Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
safe_symbol = symbol.replace('/', '_')
return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")
def save(self, data, symbol, timeframe):
"""
Save OHLCV data to cache.
Args:
data: List of dictionaries containing OHLCV data
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
"""
if not data:
logger.warning(f"No data to cache for {symbol} ({timeframe})")
return False
try:
# Convert data to a serializable format
serializable_data = []
for candle in data:
serializable_data.append({
'timestamp': candle['timestamp'],
'open': float(candle['open']),
'high': float(candle['high']),
'low': float(candle['low']),
'close': float(candle['close']),
'volume': float(candle['volume'])
})
# Create cache entry with metadata
cache_entry = {
'symbol': symbol,
'timeframe': timeframe,
'last_updated': int(time.time()),
'data': serializable_data
}
# Save to file
filename = self._get_cache_filename(symbol, timeframe)
with open(filename, 'w') as f:
json.dump(cache_entry, f)
# Update in-memory cache
cache_key = f"{symbol}_{timeframe}"
self.memory_cache[cache_key] = cache_entry
logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
return True
except Exception as e:
logger.error(f"Error saving data to cache: {e}")
return False
def load(self, symbol, timeframe, max_age_override=None):
"""
Load OHLCV data from cache.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
max_age_override: Override the default max age (in seconds)
Returns:
List of dictionaries containing OHLCV data, or None if cache is missing or stale
"""
cache_key = f"{symbol}_{timeframe}"
max_age = max_age_override if max_age_override is not None else self.max_age_seconds
try:
# Check in-memory cache first
if cache_key in self.memory_cache:
cache_entry = self.memory_cache[cache_key]
# Check if cache is fresh
cache_age = int(time.time()) - cache_entry['last_updated']
if cache_age <= max_age:
logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes")
return cache_entry['data']
# Check file cache
filename = self._get_cache_filename(symbol, timeframe)
if not os.path.exists(filename):
logger.info(f"No cache file found for {symbol} ({timeframe})")
return None
# Load cache file
with open(filename, 'r') as f:
cache_entry = json.load(f)
# Check if cache is fresh
cache_age = int(time.time()) - cache_entry['last_updated']
if cache_age > max_age:
logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)")
return None
# Update in-memory cache
self.memory_cache[cache_key] = cache_entry
logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})")
return cache_entry['data']
except Exception as e:
logger.error(f"Error loading data from cache: {e}")
return None
def append(self, new_candle, symbol, timeframe):
"""
Append a new candle to the cached data.
Args:
new_candle: Dictionary containing a single OHLCV candle
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
Boolean indicating success
"""
try:
# Load existing data
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for append
if data is None:
data = []
# Check if the candle already exists (same timestamp)
for i, candle in enumerate(data):
if candle['timestamp'] == new_candle['timestamp']:
# Update existing candle
data[i] = {
'timestamp': new_candle['timestamp'],
'open': float(new_candle['open']),
'high': float(new_candle['high']),
'low': float(new_candle['low']),
'close': float(new_candle['close']),
'volume': float(new_candle['volume'])
}
# Save updated data
return self.save(data, symbol, timeframe)
# Append new candle
data.append({
'timestamp': new_candle['timestamp'],
'open': float(new_candle['open']),
'high': float(new_candle['high']),
'low': float(new_candle['low']),
'close': float(new_candle['close']),
'volume': float(new_candle['volume'])
})
# Save updated data
return self.save(data, symbol, timeframe)
except Exception as e:
logger.error(f"Error appending candle to cache: {e}")
return False
def get_latest_timestamp(self, symbol, timeframe):
"""
Get the timestamp of the most recent candle in the cache.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
Timestamp (milliseconds) of the most recent candle, or None if cache is empty
"""
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for this check
if not data:
return None
# Find the most recent timestamp
latest_timestamp = max(candle['timestamp'] for candle in data)
return latest_timestamp
def clear(self, symbol=None, timeframe=None):
"""
Clear cache for a specific symbol and timeframe, or all cache if not specified.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes
Returns:
Number of cache files deleted
"""
count = 0
try:
if symbol and timeframe:
# Clear specific cache
filename = self._get_cache_filename(symbol, timeframe)
if os.path.exists(filename):
os.remove(filename)
count = 1
# Clear from memory cache
cache_key = f"{symbol}_{timeframe}"
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
else:
# Clear all matching caches
for filename in os.listdir(self.cache_dir):
file_path = os.path.join(self.cache_dir, filename)
# Skip directories
if not os.path.isfile(file_path):
continue
# Check if file matches the filter
should_delete = True
if symbol:
safe_symbol = symbol.replace('/', '_')
if not filename.startswith(f"{safe_symbol}_"):
should_delete = False
if timeframe:
if not filename.endswith(f"_{timeframe}.json"):
should_delete = False
# Delete file if it matches the filter
if should_delete:
os.remove(file_path)
count += 1
# Clear memory cache
keys_to_delete = []
for cache_key in self.memory_cache:
should_delete = True
if symbol:
if not cache_key.startswith(f"{symbol}_"):
should_delete = False
if timeframe:
if not cache_key.endswith(f"_{timeframe}"):
should_delete = False
if should_delete:
keys_to_delete.append(cache_key)
for key in keys_to_delete:
del self.memory_cache[key]
logger.info(f"Cleared {count} cache files")
return count
except Exception as e:
logger.error(f"Error clearing cache: {e}")
return 0
def to_dataframe(self, symbol, timeframe):
"""
Convert cached OHLCV data to a pandas DataFrame.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
pandas DataFrame with OHLCV data, or None if cache is missing
"""
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for conversion
if not data:
return None
# Convert to DataFrame
df = pd.DataFrame(data)
# Convert timestamp to datetime
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
# Set datetime as index
df.set_index('datetime', inplace=True)
return df
# Create a global instance for easy access
ohlcv_cache = OHLCVCache()

View File

@ -0,0 +1,449 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class EnhancedPricePredictionModel(nn.Module):
def __init__(self, input_dim=2, hidden_dim=256, num_layers=3, output_dim=5, num_timeframes=3):
super(EnhancedPricePredictionModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_timeframes = num_timeframes
# Separate LSTM for each timeframe
self.timeframe_lstms = nn.ModuleList([
nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
for _ in range(num_timeframes)
])
# Cross-timeframe attention
self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
# Self-attention for each timeframe
self.self_attentions = nn.ModuleList([
nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
for _ in range(num_timeframes)
])
# Timeframe fusion layer
self.fusion_layer = nn.Sequential(
nn.Linear(hidden_dim * num_timeframes, hidden_dim * 2),
nn.LeakyReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim * 2, hidden_dim)
)
# Fully connected layer for price prediction
self.price_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim)
)
# Fully connected layer for extrema prediction (high and low points)
self.extrema_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
)
# Volume prediction layer
self.volume_fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x_list):
# x_list is a list of tensors, one for each timeframe
# Each x shape: (batch_size, seq_len, input_dim)
# Process each timeframe with its own LSTM
lstm_outputs = []
for i, x in enumerate(x_list):
lstm_out, _ = self.timeframe_lstms[i](x) # lstm_out: (batch_size, seq_len, hidden_dim)
lstm_outputs.append(lstm_out)
# Apply self-attention to each timeframe
attn_outputs = []
for i, lstm_out in enumerate(lstm_outputs):
attn_output, _ = self.self_attentions[i](lstm_out, lstm_out, lstm_out)
attn_outputs.append(attn_output[:, -1, :]) # Use the last time step
# Concatenate all timeframe representations
combined = torch.cat(attn_outputs, dim=1) # (batch_size, hidden_dim * num_timeframes)
# Fuse timeframe information
fused = self.fusion_layer(combined) # (batch_size, hidden_dim)
# Price prediction
price_pred = self.price_fc(fused)
# Extrema prediction
extrema_logits = self.extrema_fc(fused)
# Volume prediction
volume_pred = self.volume_fc(fused)
return price_pred, extrema_logits, volume_pred
class EnhancedDQN(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=512):
super(EnhancedDQN, self).__init__()
# Feature extraction layers with increased capacity
self.feature_extraction = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
)
# Advantage stream with increased capacity
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, action_dim)
)
# Value stream with increased capacity
self.value_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, 1)
)
# Enhanced transformer for temporal dependencies
encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=hidden_dim*4,
dropout=0.1,
batch_first=True
)
self.transformer = TransformerEncoder(encoder_layers, num_layers=3)
# LSTM for sequential decision making with increased capacity
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.1)
# Final layers with increased capacity
self.final_layers = nn.Sequential(
nn.Linear(hidden_dim*2, hidden_dim),
nn.LeakyReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, action_dim)
)
# Market regime classification layer
self.market_regime_classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LeakyReLU(),
nn.Linear(hidden_dim // 2, 3) # 3 regimes: trending, ranging, volatile
)
def forward(self, state, hidden=None):
# Extract features
features = self.feature_extraction(state)
features = features.unsqueeze(1) # Add sequence dimension for transformer/LSTM
# Transformer processing
transformer_out = self.transformer(features)
# LSTM processing
lstm_out, lstm_hidden = self.lstm(transformer_out)
# Dueling architecture
advantage = self.advantage_stream(features.squeeze(1))
value = self.value_stream(features.squeeze(1))
# Combine transformer, LSTM and dueling outputs
combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
q_values = self.final_layers(combined)
# Market regime classification
market_regime = self.market_regime_classifier(transformer_out.squeeze(1))
# Dueling Q-value computation
dueling_q = value + advantage - advantage.mean(dim=1, keepdim=True)
# Final Q-values are a weighted combination of the dueling Q-values and the direct Q-values
# This allows the model to use either approach depending on the situation
q_values = 0.5 * dueling_q + 0.5 * q_values
return q_values, lstm_hidden, market_regime
class EnhancedReplayBuffer:
"""Enhanced replay buffer with prioritized experience replay and n-step returns"""
def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, n_step=3, gamma=0.99):
self.capacity = capacity
self.buffer = []
self.position = 0
self.priorities = torch.zeros(capacity)
self.alpha = alpha # Priority exponent
self.beta = beta # Importance sampling weight
self.beta_increment = beta_increment # Beta annealing
self.n_step = n_step # n-step returns
self.gamma = gamma # Discount factor
self.n_step_buffer = []
self.max_priority = 1.0
def add(self, state, action, reward, next_state, done):
"""
Add a new experience to the buffer (simplified version of push for compatibility)
Args:
state: Current state
action: Action taken
reward: Reward received
next_state: Next state
done: Whether the episode is done
"""
# Store in replay buffer with max priority
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
# Set priority to max priority to ensure it gets sampled
self.priorities[self.position] = self.max_priority
# Move position pointer
self.position = (self.position + 1) % self.capacity
def push(self, state, action, reward, next_state, done):
# Store experience in n-step buffer
self.n_step_buffer.append((state, action, reward, next_state, done))
# If we don't have enough experiences for n-step return, wait
if len(self.n_step_buffer) < self.n_step and not done:
return
# Calculate n-step return
reward_n = 0
for i in range(self.n_step):
if i >= len(self.n_step_buffer):
break
reward_n += self.gamma**i * self.n_step_buffer[i][2]
# Get state, action from the first experience
state = self.n_step_buffer[0][0]
action = self.n_step_buffer[0][1]
# Get next_state, done from the last experience
next_state = self.n_step_buffer[-1][3]
done = self.n_step_buffer[-1][4]
# Store in replay buffer with max priority
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward_n, next_state, done)
# Set priority to max priority to ensure it gets sampled
self.priorities[self.position] = self.max_priority
# Move position pointer
self.position = (self.position + 1) % self.capacity
# Remove the first experience from n-step buffer
self.n_step_buffer.pop(0)
# If episode is done, clear n-step buffer
if done:
self.n_step_buffer = []
def sample(self, batch_size):
# Calculate sampling probabilities
if len(self.buffer) < self.capacity:
probs = self.priorities[:len(self.buffer)]
else:
probs = self.priorities
# Normalize probabilities
probs = probs ** self.alpha
probs = probs / probs.sum()
# Sample indices based on priorities
indices = torch.multinomial(probs, batch_size, replacement=True)
# Get samples
states = []
actions = []
rewards = []
next_states = []
dones = []
# Calculate importance sampling weights
weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
weights = weights / weights.max()
self.beta = min(1.0, self.beta + self.beta_increment) # Anneal beta
# Get experiences
for idx in indices:
state, action, reward, next_state, done = self.buffer[idx]
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(done)
# Return only the states, actions, rewards, next_states, dones for compatibility with learn function
return states, actions, rewards, next_states, dones
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors):
# Update priority based on TD error
priority = float(abs(td_error) + 1e-5) # Small constant to ensure non-zero priority
self.priorities[idx] = priority
self.max_priority = max(self.max_priority, priority)
def __len__(self):
return len(self.buffer)
def train_price_predictor(model, data_loaders, optimizer, device, epochs=10):
"""
Train the price prediction model using data from multiple timeframes
Args:
model: The EnhancedPricePredictionModel
data_loaders: List of DataLoader objects, one for each timeframe
optimizer: Optimizer for training
device: Device to train on (CPU or GPU)
epochs: Number of training epochs
"""
model.train()
for epoch in range(epochs):
total_loss = 0
num_batches = 0
# Assume all dataloaders have the same length
for batch_idx, batch_data in enumerate(zip(*data_loaders)):
# Each batch_data is a tuple of (inputs, price_targets, extrema_targets, volume_targets) for each timeframe
optimizer.zero_grad()
# Prepare inputs for each timeframe
inputs_list = [data[0].to(device) for data in batch_data]
price_targets = batch_data[0][1].to(device) # Use targets from the first timeframe (e.g., 1m)
extrema_targets = batch_data[0][2].to(device)
volume_targets = batch_data[0][3].to(device)
# Forward pass
price_pred, extrema_logits, volume_pred = model(inputs_list)
# Calculate losses
price_loss = F.mse_loss(price_pred, price_targets)
extrema_loss = F.binary_cross_entropy_with_logits(extrema_logits, extrema_targets)
volume_loss = F.mse_loss(volume_pred, volume_targets)
# Combined loss with weighting
loss = price_loss + 0.5 * extrema_loss + 0.3 * volume_loss
# Backward pass
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
num_batches += 1
if batch_idx % 100 == 0:
print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}")
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.6f}")
# Learning rate scheduling
if epoch > 0 and epoch % 5 == 0:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.9
return model
def prepare_multi_timeframe_data(exchange, timeframes=['1m', '15m', '1h'], lookback=30):
"""
Prepare data from multiple timeframes for training
Args:
exchange: Exchange object to fetch data from
timeframes: List of timeframes to fetch
lookback: Number of candles to look back
Returns:
List of DataLoader objects, one for each timeframe
"""
data_loaders = []
for timeframe in timeframes:
# Fetch historical data for this timeframe
candles = exchange.fetch_ohlcv(timeframe=timeframe, limit=1000)
# Prepare inputs and targets
inputs = []
price_targets = []
extrema_targets = []
volume_targets = []
for i in range(lookback, len(candles) - 5):
# Input: lookback candles (price and volume)
input_data = torch.tensor([
[candle[4], candle[5]] for candle in candles[i-lookback:i]
], dtype=torch.float32)
# Target: next 5 candles (price)
price_target = torch.tensor([
candle[4] for candle in candles[i:i+5]
], dtype=torch.float32)
# Target: extrema points in next 5 candles
extrema_target = torch.zeros(10, dtype=torch.float32) # 5 time steps, 2 classes each
for j in range(5):
# Simple extrema detection for demonstration
if j > 0 and j < 4:
# Local high
if candles[i+j][2] > candles[i+j-1][2] and candles[i+j][2] > candles[i+j+1][2]:
extrema_target[j*2] = 1.0
# Local low
if candles[i+j][3] < candles[i+j-1][3] and candles[i+j][3] < candles[i+j+1][3]:
extrema_target[j*2+1] = 1.0
# Target: volume for next 5 candles
volume_target = torch.tensor([
candle[5] for candle in candles[i:i+5]
], dtype=torch.float32)
inputs.append(input_data)
price_targets.append(price_target)
extrema_targets.append(extrema_target)
volume_targets.append(volume_target)
# Create dataset and dataloader
dataset = torch.utils.data.TensorDataset(
torch.stack(inputs),
torch.stack(price_targets),
torch.stack(extrema_targets),
torch.stack(volume_targets)
)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=32, shuffle=True
)
data_loaders.append(data_loader)
return data_loaders

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,806 @@
import numpy as np
import pandas as pd
import os
import random
from datetime import datetime, timedelta
class ExchangeSimulator:
"""
A simple exchange simulator that generates realistic market data
for testing trading algorithms without connecting to a real exchange.
"""
def __init__(self, symbol="BTC/USDT", seed=42):
"""
Initialize the exchange simulator
Args:
symbol: Trading pair symbol
seed: Random seed for reproducibility
"""
self.symbol = symbol
self.seed = seed
np.random.seed(seed)
random.seed(seed)
# Initialize data storage
self.data = {}
self.current_timestamp = datetime.now()
# Generate initial data for different timeframes
self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d']
self.timeframe_minutes = {
'1m': 1,
'5m': 5,
'15m': 15,
'30m': 30,
'1h': 60,
'4h': 240,
'1d': 1440
}
# Generate initial price around $50,000 (for BTC/USDT)
self.base_price = 50000.0
# Generate data for each timeframe
for tf in self.timeframes:
self._generate_initial_data(tf)
def _generate_initial_data(self, timeframe, num_candles=1000):
"""
Generate initial historical data for a specific timeframe
Args:
timeframe: Timeframe to generate data for
num_candles: Number of candles to generate
"""
# Calculate time delta for this timeframe
minutes = self.timeframe_minutes[timeframe]
# Generate timestamps
end_time = self.current_timestamp
timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)]
timestamps.reverse() # Oldest first
# Generate price data with realistic patterns
prices = self._generate_price_series(num_candles)
# Generate volume data with realistic patterns
volumes = self._generate_volume_series(num_candles, timeframe)
# Create OHLCV data
ohlcv_data = []
for i in range(num_candles):
# Calculate OHLC based on close price
close = prices[i]
high = close * (1 + np.random.uniform(0, 0.01))
low = close * (1 - np.random.uniform(0, 0.01))
open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005))
# Create candle
candle = [
int(timestamps[i].timestamp() * 1000), # Timestamp in milliseconds
open_price, # Open
high, # High
low, # Low
close, # Close
volumes[i] # Volume
]
ohlcv_data.append(candle)
# Store data
self.data[timeframe] = ohlcv_data
def _generate_price_series(self, length):
"""
Generate a realistic price series with trends, reversals, and volatility
Args:
length: Number of prices to generate
Returns:
List of prices
"""
# Start with base price
prices = [self.base_price]
# Parameters for price generation
trend_strength = 0.001 # Strength of trend
volatility = 0.005 # Daily volatility
mean_reversion = 0.001 # Mean reversion strength
# Generate price series
for i in range(1, length):
# Determine if we're in a trend
if i % 100 == 0:
# Change trend direction every ~100 candles
trend_strength = -trend_strength
# Calculate price change
trend = trend_strength * prices[-1]
random_change = np.random.normal(0, volatility) * prices[-1]
mean_reversion_change = mean_reversion * (self.base_price - prices[-1])
# Calculate new price
new_price = prices[-1] + trend + random_change + mean_reversion_change
# Ensure price doesn't go negative
new_price = max(new_price, prices[-1] * 0.9)
prices.append(new_price)
return prices
def _generate_volume_series(self, length, timeframe):
"""
Generate a realistic volume series with patterns
Args:
length: Number of volumes to generate
timeframe: Timeframe for volume scaling
Returns:
List of volumes
"""
# Base volume depends on timeframe
base_volume = {
'1m': 10,
'5m': 50,
'15m': 150,
'30m': 300,
'1h': 600,
'4h': 2400,
'1d': 10000
}[timeframe]
# Generate volume series
volumes = []
for i in range(length):
# Volume tends to be higher at trend reversals and during volatile periods
cycle_factor = 1 + 0.5 * np.sin(i / 20) # Cyclical pattern
random_factor = np.random.lognormal(0, 0.5) # Random spikes
# Calculate volume
volume = base_volume * cycle_factor * random_factor
# Add some volume spikes
if random.random() < 0.05: # 5% chance of volume spike
volume *= random.uniform(2, 5)
volumes.append(volume)
return volumes
def fetch_ohlcv(self, timeframe='1m', limit=100, since=None):
"""
Fetch OHLCV data for a specific timeframe
Args:
timeframe: Timeframe to fetch data for
limit: Number of candles to fetch
since: Timestamp to fetch data since (not used in simulator)
Returns:
List of OHLCV candles
"""
# Ensure timeframe exists
if timeframe not in self.data:
if timeframe in self.timeframe_minutes:
self._generate_initial_data(timeframe)
else:
# Default to 1m if timeframe not supported
timeframe = '1m'
# Get data
data = self.data[timeframe]
# Return limited data
return data[-limit:]
def update(self):
"""
Update the exchange data by generating a new candle for each timeframe
"""
# Update current timestamp
self.current_timestamp = datetime.now()
# Update each timeframe
for tf in self.timeframes:
self._add_new_candle(tf)
def _add_new_candle(self, timeframe):
"""
Add a new candle to the specified timeframe
Args:
timeframe: Timeframe to add candle to
"""
# Get existing data
data = self.data[timeframe]
# Get last close price
last_close = data[-1][4]
# Calculate time delta for this timeframe
minutes = self.timeframe_minutes[timeframe]
# Calculate new timestamp
new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000)
# Generate new price with some randomness
price_change = np.random.normal(0, 0.002) * last_close
new_close = last_close + price_change
# Calculate OHLC
new_open = last_close
new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005))
new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005))
# Generate volume
base_volume = data[-1][5]
volume_change = np.random.normal(0, 0.2) * base_volume
new_volume = max(base_volume + volume_change, base_volume * 0.5)
# Create new candle
new_candle = [
new_timestamp,
new_open,
new_high,
new_low,
new_close,
new_volume
]
# Add to data
self.data[timeframe].append(new_candle)
def get_ticker(self, symbol=None):
"""
Get current ticker information
Args:
symbol: Symbol to get ticker for (defaults to initialized symbol)
Returns:
Dictionary with ticker information
"""
if symbol is None:
symbol = self.symbol
# Get latest 1m candle
latest_candle = self.data['1m'][-1]
return {
'symbol': symbol,
'bid': latest_candle[4] * 0.9999, # Slightly below last price
'ask': latest_candle[4] * 1.0001, # Slightly above last price
'last': latest_candle[4],
'high': latest_candle[2],
'low': latest_candle[3],
'volume': latest_candle[5],
'timestamp': latest_candle[0]
}
def create_order(self, symbol, type, side, amount, price=None):
"""
Simulate creating an order
Args:
symbol: Symbol to create order for
type: Order type (limit, market)
side: Order side (buy, sell)
amount: Order amount
price: Order price (for limit orders)
Returns:
Dictionary with order information
"""
# Get current ticker
ticker = self.get_ticker(symbol)
# Determine execution price
if type == 'market':
if side == 'buy':
execution_price = ticker['ask']
else:
execution_price = ticker['bid']
else: # limit order
execution_price = price
# Create order object
order = {
'id': f"order_{int(datetime.now().timestamp() * 1000)}",
'symbol': symbol,
'type': type,
'side': side,
'amount': amount,
'price': execution_price,
'cost': amount * execution_price,
'filled': amount,
'status': 'closed',
'timestamp': int(datetime.now().timestamp() * 1000)
}
return order
def fetch_balance(self):
"""
Fetch account balance (simulated)
Returns:
Dictionary with balance information
"""
return {
'total': {
'USD': 10000.0,
'BTC': 1.0
},
'free': {
'USD': 5000.0,
'BTC': 0.5
},
'used': {
'USD': 5000.0,
'BTC': 0.5
}
}
def reset(self):
"""
Reset the exchange simulator to its initial state
Returns:
Self for method chaining
"""
# Reset timestamp
self.current_timestamp = datetime.now()
# Regenerate data for each timeframe
for tf in self.timeframes:
self._generate_initial_data(tf)
# Reset any internal state
self.position = 'flat'
self.position_size = 0
self.entry_price = 0
self.stop_loss = 0
self.take_profit = 0
# Reset prediction history if it exists
if hasattr(self, 'prediction_history'):
self.prediction_history = []
return self
def step(self, action):
"""
Take a step in the environment by executing an action
Args:
action: Action to take (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT)
Returns:
next_state: Next state after taking action
reward: Reward received
done: Whether episode is done
info: Additional information
"""
# Get current price
current_price = self.data['1m'][-1][4]
# Initialize info dictionary
info = {
'price': current_price,
'timestamp': self.data['1m'][-1][0],
'trade': None
}
# Process action
if action == 0: # HOLD
pass # No action needed
elif action == 1: # BUY/LONG
if self.position == 'flat':
# Open a new long position
self.position = 'long'
self.entry_price = current_price
self.position_size = 100 # Simplified position sizing
# Set stop loss and take profit levels
self.stop_loss = current_price * 0.99 # 1% stop loss
self.take_profit = current_price * 1.02 # 2% take profit
# Record entry time
self.entry_time = self.data['1m'][-1][0]
# Add to info
info['trade'] = {
'type': 'long',
'entry': current_price,
'entry_time': self.data['1m'][-1][0],
'size': self.position_size,
'stop_loss': self.stop_loss,
'take_profit': self.take_profit
}
elif self.position == 'short':
# Close short position and open long
pnl = self.entry_price - current_price
pnl_percent = pnl / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'short',
'entry': self.entry_price,
'exit': current_price,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Open new long position
self.position = 'long'
self.entry_price = current_price
self.position_size = 100 # Simplified position sizing
# Set stop loss and take profit levels
self.stop_loss = current_price * 0.99 # 1% stop loss
self.take_profit = current_price * 1.02 # 2% take profit
# Record entry time
self.entry_time = self.data['1m'][-1][0]
elif action == 2: # SELL/SHORT
if self.position == 'flat':
# Open a new short position
self.position = 'short'
self.entry_price = current_price
self.position_size = 100 # Simplified position sizing
# Set stop loss and take profit levels
self.stop_loss = current_price * 1.01 # 1% stop loss
self.take_profit = current_price * 0.98 # 2% take profit
# Record entry time
self.entry_time = self.data['1m'][-1][0]
# Add to info
info['trade'] = {
'type': 'short',
'entry': current_price,
'entry_time': self.data['1m'][-1][0],
'size': self.position_size,
'stop_loss': self.stop_loss,
'take_profit': self.take_profit
}
elif self.position == 'long':
# Close long position and open short
pnl = current_price - self.entry_price
pnl_percent = pnl / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'long',
'entry': self.entry_price,
'exit': current_price,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Open new short position
self.position = 'short'
self.entry_price = current_price
self.position_size = 100 # Simplified position sizing
# Set stop loss and take profit levels
self.stop_loss = current_price * 1.01 # 1% stop loss
self.take_profit = current_price * 0.98 # 2% take profit
# Record entry time
self.entry_time = self.data['1m'][-1][0]
# Generate next candle
self._add_new_candle('1m')
# Check if stop loss or take profit has been hit
self._check_sl_tp(info)
# Validate predictions if available
if hasattr(self, 'prediction_history') and len(self.prediction_history) > 0:
self.validate_predictions(self.data['1m'][-1])
# Prepare next state (simplified)
next_state = self._get_state()
# Calculate reward (simplified)
reward = 0
if info['trade'] is not None and 'pnl_dollar' in info['trade']:
reward = info['trade']['pnl_dollar']
# Check if done (simplified)
done = False
return next_state, reward, done, info
def _get_state(self):
"""
Get the current state of the environment
Returns:
List representing the current state
"""
# Simplified state representation
state = []
# Add price features
for tf in ['1m', '5m', '15m']:
if tf in self.data:
# Get last 10 candles
candles = self.data[tf][-10:]
# Extract close prices
prices = [candle[4] for candle in candles]
# Calculate price changes
price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
# Add to state
state.extend(price_changes)
# Add current price relative to SMA
sma_5 = sum(prices[-5:]) / 5
sma_10 = sum(prices) / 10
state.append(prices[-1] / sma_5 - 1)
state.append(prices[-1] / sma_10 - 1)
# Pad state to 100 dimensions
while len(state) < 100:
state.append(0)
# Ensure state has exactly 100 dimensions
if len(state) > 100:
state = state[:100]
return state
def _check_sl_tp(self, info):
"""
Check if stop loss or take profit has been hit
Args:
info: Info dictionary to update
"""
if self.position == 'flat':
return
# Get current price
current_price = self.data['1m'][-1][4]
if self.position == 'long':
# Check stop loss
if current_price <= self.stop_loss:
# Stop loss hit
pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'long',
'entry': self.entry_price,
'exit': self.stop_loss,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'reason': 'stop_loss',
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif current_price >= self.take_profit:
# Take profit hit
pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'long',
'entry': self.entry_price,
'exit': self.take_profit,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'reason': 'take_profit',
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
elif self.position == 'short':
# Check stop loss
if current_price >= self.stop_loss:
# Stop loss hit
pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'short',
'entry': self.entry_price,
'exit': self.stop_loss,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'reason': 'stop_loss',
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
# Check take profit
elif current_price <= self.take_profit:
# Take profit hit
pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Add to info
info['trade'] = {
'type': 'short',
'entry': self.entry_price,
'exit': self.take_profit,
'entry_time': self.entry_time,
'exit_time': self.data['1m'][-1][0],
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'reason': 'take_profit',
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
}
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
def validate_predictions(self, new_candle):
"""
Validate previous extrema predictions against new candle data
Args:
new_candle: New candle data to validate against
"""
if not hasattr(self, 'prediction_history') or not self.prediction_history:
return
# Extract candle data
timestamp = new_candle[0]
high_price = new_candle[2]
low_price = new_candle[3]
# Track validation metrics
validated_count = 0
correct_count = 0
# Check each prediction that hasn't been validated yet
for pred in self.prediction_history:
if pred.get('validated', False):
continue
# Check if this prediction's time has come (or passed)
if 'predicted_timestamp' in pred and timestamp >= pred['predicted_timestamp']:
pred['validated'] = True
validated_count += 1
# Check if prediction was correct
if pred['type'] == 'low':
# A low prediction is correct if price went within 0.5% of predicted low
price_diff_percent = abs(low_price - pred['price']) / pred['price'] * 100
pred['actual_price'] = low_price
pred['price_diff_percent'] = price_diff_percent
# Consider correct if within 0.5% or price went lower than predicted
was_correct = price_diff_percent < 0.5 or low_price <= pred['price']
pred['was_correct'] = was_correct
if was_correct:
correct_count += 1
elif pred['type'] == 'high':
# A high prediction is correct if price went within 0.5% of predicted high
price_diff_percent = abs(high_price - pred['price']) / pred['price'] * 100
pred['actual_price'] = high_price
pred['price_diff_percent'] = price_diff_percent
# Consider correct if within 0.5% or price went higher than predicted
was_correct = price_diff_percent < 0.5 or high_price >= pred['price']
pred['was_correct'] = was_correct
if was_correct:
correct_count += 1
# Return validation metrics
if validated_count > 0:
return {
'validated_count': validated_count,
'correct_count': correct_count,
'accuracy': correct_count / validated_count
}
return None
def calculate_pnl(self):
"""
Calculate the current profit/loss of the open position
Returns:
float: Current PnL in dollars, 0 if no position is open
"""
if self.position == 'flat':
return 0.0
current_price = self.data['1m'][-1][4]
if self.position == 'long':
pnl_percent = (current_price - self.entry_price) / self.entry_price * 100
elif self.position == 'short':
pnl_percent = (self.entry_price - current_price) / self.entry_price * 100
else:
return 0.0
pnl_dollar = pnl_percent / 100 * self.position_size
return pnl_dollar
# Example usage
if __name__ == "__main__":
# Create exchange simulator
exchange = ExchangeSimulator()
# Fetch some data
ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10)
print("OHLCV data (1h timeframe):")
for candle in ohlcv[-5:]:
timestamp = datetime.fromtimestamp(candle[0] / 1000)
print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}")
# Get current ticker
ticker = exchange.get_ticker()
print(f"\nCurrent ticker: {ticker['last']:.2f}")
# Create a market buy order
order = exchange.create_order("BTC/USDT", "market", "buy", 0.1)
print(f"\nCreated order: {order}")
# Update the exchange (simulate time passing)
exchange.update()
# Get updated ticker
updated_ticker = exchange.get_ticker()
print(f"\nUpdated ticker: {updated_ticker['last']:.2f}")

View File

@ -0,0 +1,42 @@
def fix_indentation():
with open('main.py', 'r') as f:
lines = f.readlines()
# Fix indentation for the problematic sections
fixed_lines = []
# Find the try block that starts at line 1693
try_start_line = 1693
try_block_found = False
in_try_block = False
for i, line in enumerate(lines):
# Check if we're at the try statement
if i+1 == try_start_line and 'try:' in line:
try_block_found = True
in_try_block = True
fixed_lines.append(line)
# Fix the indentation of the experiences line
elif i+1 == 1695 and line.strip().startswith('experiences = self.memory.sample(BATCH_SIZE)'):
# Add proper indentation (4 spaces)
fixed_lines.append(' ' + line.lstrip())
# Check if we're at the end of the try block without an except
elif try_block_found and in_try_block and i+1 > try_start_line and line.strip() and not line.startswith(' '):
# We've reached the end of the try block without an except, add one
fixed_lines.append(' except Exception as e:\n')
fixed_lines.append(' logger.error(f"Error during learning: {e}")\n')
fixed_lines.append(' logger.error(f"Traceback: {traceback.format_exc()}")\n')
fixed_lines.append(' return None\n\n')
in_try_block = False
fixed_lines.append(line)
else:
fixed_lines.append(line)
# Write the fixed content back to the file
with open('main.py', 'w') as f:
f.writelines(fixed_lines)
print("Indentation fixed!")
if __name__ == "__main__":
fix_indentation()

View File

@ -0,0 +1,22 @@
import re
def fix_try_blocks():
with open('main.py', 'r') as f:
content = f.read()
# Find all try blocks without except or finally
pattern = r'(\s+)try:\s*\n((?:\1\s+.*\n)+?)(?!\1\s*except|\1\s*finally)'
# Replace with try-except blocks
fixed_content = re.sub(pattern,
r'\1try:\n\2\1except Exception as e:\n\1 logger.error(f"Error: {e}")\n\1 logger.error(f"Traceback: {traceback.format_exc()}")\n\1 return None\n\n',
content)
# Write the fixed content back to the file
with open('main.py', 'w') as f:
f.write(fixed_content)
print("Try blocks fixed!")
if __name__ == "__main__":
fix_try_blocks()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,24 @@
import asyncio
from exchange_simulator import ExchangeSimulator
import logging
# Set up logging
logger = logging.getLogger(__name__)
async def main():
"""
Main function to run the training process.
"""
# Initialize exchange simulator
exchange = ExchangeSimulator()
# Train agent
print("Starting training process...")
# Add your training code here
print("Training complete!")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
import asyncio
import logging
from exchange_simulator import ExchangeSimulator
# Set up logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
async def main():
"""
Main function to run the training process.
"""
# Initialize exchange simulator
exchange = ExchangeSimulator()
# Train agent
print("Starting training process...")
# Add your training code here
print("Training complete!")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")
except Exception as e:
logger.error(f"Error running main: {e}")
import traceback
logger.error(traceback.format_exc())

View File

@ -0,0 +1,304 @@
import os
import json
import logging
import traceback
import numpy as np
from dotenv import load_dotenv
from mexc_api.spot import Spot
from mexc_api.common.enums import Side, OrderType
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('mexc_trading')
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
class MexcTradingClient:
"""Client for executing trades on MEXC exchange using the official API"""
def __init__(self, api_key=None, api_secret=None, symbol="ETH/USDT", leverage=1):
"""Initialize the MEXC trading client"""
self.api_key = api_key or MEXC_API_KEY
self.api_secret = api_secret or MEXC_SECRET_KEY
# Ensure API keys are not None
if not self.api_key or not self.api_secret:
logger.warning("API keys not provided. Using empty strings for public endpoints only.")
self.api_key = ""
self.api_secret = ""
self.symbol = symbol
self.formatted_symbol = symbol.replace('/', '') # MEXC requires no slash
self.leverage = leverage
self.client = None
self.position = 'flat' # 'flat', 'long', or 'short'
self.position_size = 0
self.entry_price = 0
self.stop_loss = 0
self.take_profit = 0
self.open_orders = []
self.order_history = []
self.trades = []
self.win_count = 0
self.loss_count = 0
# Initialize the MEXC API client
self.initialize_client()
def initialize_client(self):
"""Initialize the MEXC API client using the API"""
try:
self.client = Spot(self.api_key, self.api_secret)
# Test connection
server_time = self.client.market.server_time()
logger.info(f"MEXC API client initialized successfully. Server time: {server_time}")
return True
except Exception as e:
logger.error(f"Failed to initialize MEXC API client: {e}")
logger.error(traceback.format_exc())
return False
async def fetch_account_balance(self):
"""Fetch account balance from MEXC API"""
try:
# Check if we have API keys for private endpoints
if not self.api_key or self.api_key == "":
logger.warning("No API keys provided. Cannot fetch account balance.")
return 0
account_info = self.client.account.account_info()
if 'balances' in account_info:
# Find USDT balance
for asset in account_info['balances']:
if asset['asset'] == 'USDT':
return float(asset['free'])
logger.warning("Could not find USDT balance")
return 0
except Exception as e:
logger.error(f"Error fetching account balance: {e}")
return 0
async def fetch_open_positions(self):
"""Fetch open positions from MEXC API"""
try:
# Check if we have API keys for private endpoints
if not self.api_key or self.api_key == "":
logger.warning("No API keys provided. Cannot fetch open positions.")
return []
# Fetch open orders
open_orders = self.client.account.open_orders(self.formatted_symbol)
return open_orders
except Exception as e:
logger.error(f"Error fetching open positions: {e}")
return []
async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
"""Open a new position using MEXC API"""
try:
# Check if we have API keys for private endpoints
if not self.api_key or self.api_key == "":
logger.warning("No API keys provided. Cannot open position.")
return False
# Calculate quantity based on size and price
quantity = size / entry_price
# Round quantity to appropriate precision
quantity = round(quantity, 4) # Adjust precision as needed for your asset
# Determine order side
side = Side.BUY if position_type == 'long' else Side.SELL
logger.info(f"Opening {position_type} position: {quantity} {self.symbol} at market price")
# Place market order
order_result = self.client.account.new_order(
self.formatted_symbol,
side,
OrderType.MARKET,
str(quantity)
)
logger.info(f"Market order result: {order_result}")
# Check if order was filled
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
# Get actual entry price
actual_entry_price = float(order_result.get('price', entry_price))
# Place stop loss order
sl_side = Side.SELL if position_type == 'long' else Side.BUY
sl_order = self.client.account.new_order(
self.formatted_symbol,
sl_side,
OrderType.STOP_LOSS_LIMIT,
str(quantity),
price=str(stop_loss),
stop_price=str(stop_loss),
time_in_force="GTC"
)
logger.info(f"Stop loss order placed: {sl_order}")
# Place take profit order
tp_side = Side.SELL if position_type == 'long' else Side.BUY
tp_order = self.client.account.new_order(
self.formatted_symbol,
tp_side,
OrderType.TAKE_PROFIT_LIMIT,
str(quantity),
price=str(take_profit),
stop_price=str(take_profit),
time_in_force="GTC"
)
logger.info(f"Take profit order placed: {tp_order}")
# Update local state
self.position = position_type
self.position_size = size
self.entry_price = actual_entry_price
self.stop_loss = stop_loss
self.take_profit = take_profit
# Track orders
self.open_orders.extend([sl_order, tp_order])
self.order_history.append(order_result)
logger.info(f"Successfully opened {position_type} position at {actual_entry_price}")
return True
else:
logger.error(f"Failed to open position: {order_result}")
return False
except Exception as e:
logger.error(f"Error opening position: {e}")
logger.error(traceback.format_exc())
return False
async def close_position(self, reason="manual_close"):
"""Close an existing position"""
if self.position == 'flat':
logger.info("No position to close")
return True
try:
# Check if we have API keys for private endpoints
if not self.api_key or self.api_key == "":
logger.warning("No API keys provided. Cannot close position.")
return False
# First, cancel any existing stop loss/take profit orders
try:
self.client.account.cancel_open_orders(self.formatted_symbol)
logger.info("Canceled all open orders")
except Exception as e:
logger.warning(f"Error canceling open orders: {e}")
# Determine order side (opposite of position)
side = Side.SELL if self.position == 'long' else Side.BUY
# Calculate quantity
quantity = self.position_size / self.entry_price
# Round quantity to appropriate precision
quantity = round(quantity, 4) # Adjust precision as needed
logger.info(f"Closing {self.position} position: {quantity} {self.symbol} at market price")
# Execute market order to close position
order_result = self.client.account.new_order(
self.formatted_symbol,
side,
OrderType.MARKET,
str(quantity)
)
logger.info(f"Close order result: {order_result}")
# Check if order was filled
if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
# Get actual exit price
exit_price = float(order_result.get('price', 0))
# Calculate PnL
if self.position == 'long':
pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
else: # short
pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100
pnl_dollar = pnl_percent / 100 * self.position_size
# Record trade
self.trades.append({
'type': self.position,
'entry': self.entry_price,
'exit': exit_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'reason': reason,
'order_id': order_result.get('orderId')
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
# Track order history
self.order_history.append(order_result)
logger.info(f"Closed {self.position} position at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
# Reset position
self.position = 'flat'
self.entry_price = 0
self.position_size = 0
self.stop_loss = 0
self.take_profit = 0
return True
else:
logger.error(f"Failed to close position: {order_result}")
return False
except Exception as e:
logger.error(f"Error closing position: {e}")
logger.error(traceback.format_exc())
return False
async def check_order_status(self, order_id):
"""Check the status of a specific order"""
try:
# Check if we have API keys for private endpoints
if not self.api_key or self.api_key == "":
logger.warning("No API keys provided. Cannot check order status.")
return None
order_status = self.client.account.query_order(
self.formatted_symbol,
order_id=order_id
)
return order_status
except Exception as e:
logger.error(f"Error checking order status: {e}")
return None
async def get_market_price(self):
"""Get current market price for the symbol"""
try:
ticker = self.client.market.ticker_price(self.formatted_symbol)
if isinstance(ticker, list) and len(ticker) > 0 and 'price' in ticker[0]:
return float(ticker[0]['price'])
elif isinstance(ticker, dict) and 'price' in ticker:
return float(ticker['price'])
else:
logger.error(f"Unexpected ticker format: {ticker}")
return None
except Exception as e:
logger.error(f"Error getting market price: {e}")
return None

View File

@ -0,0 +1,456 @@
# mexc-api-sdk
MEXC Official Market and trade api sdk, easy to connection and send request to MEXC open api !
## Prerequisites
- To use our SDK you have to install nodejs LTS (https://aws.github.io/jsii/user-guides/lib-user/)
## Installation
1.
```
git clone https://github.com/mxcdevelop/mexc-api-sdk.git
```
2. cd dist/{language} and unzip the file
3. we offer five language : dotnet, go, java, js, python
## Table of APIS
- [Init](#init)
- [Market](#market)
- [Ping](#ping)
- [Check Server Time](#check-server-time)
- [Exchange Information](#exchange-information)
- [Recent Trades List](#recent-trades-list)
- [Order Book](#order-book)
- [Old Trade Lookup](#old-trade-lookup)
- [Aggregate Trades List](#aggregate-trades-list)
- [kline Data](#kline-data)
- [Current Average Price](#current-average-price)
- [24hr Ticker Price Change Statistics](#24hr-ticker-price-change-statistics)
- [Symbol Price Ticker](#symbol-price-ticker)
- [Symbol Order Book Ticker](#symbol-order-book-ticker)
- [Trade](#trade)
- [Test New Order](#test-new-order)
- [New Order](#new-order)
- [cancel-order](#cancel-order)
- [Cancel all Open Orders on a Symbol](#cancel-all-open-orders-on-a-symbol)
- [Query Order](#query-order)
- [Current Open Orders](#current-open-orders)
- [All Orders](#all-orders)
- [Account Information](#account-information)
- [Account Trade List](#account-trade-list)
## Init
```javascript
//Javascript
import * as Mexc from 'mexc-sdk';
const apiKey = 'apiKey'
const apiSecret = 'apiSecret'
const client = new Mexc.Spot(apiKey, apiSecret);
```
```go
// Go
package main
import (
"fmt"
"mexc-sdk/mexcsdk"
)
func main() {
apiKey := "apiKey"
apiSecret := "apiSecret"
spot := mexcsdk.NewSpot(apiKey, apiSecret)
}
```
```python
# python
from mexc_sdk import Spot
spot = Spot(api_key='apiKey', api_secret='apiSecret')
```
```java
// java
import Mexc.Sdk.*;
class MyClass {
public static void main(String[] args) {
String apiKey= "apiKey";
String apiSecret= "apiSecret";
Spot mySpot = new Spot(apiKey, apiSecret);
}
}
```
```C#
// dotnet
using System;
using System.Collections.Generic;
using Mxc.Sdk;
namespace dotnet
{
class Program
{
static void Main(string[] args)
{
string apiKey = "apiKey";
string apiSecret= "apiSecret";
var spot = new Spot(apiKey, apiSecret);
}
}
}
```
## Market
### Ping
```javascript
client.ping()
```
### Check Server Time
```javascript
client.time()
```
### Exchange Information
```javascript
client.exchangeInfo(options: any)
options:{symbol, symbols}
/**
* choose one parameter
*
* symbol :
* example "BNBBTC";
*
* symbols :
* array of symbol
* example ["BTCUSDT","BNBBTC"];
*
*/
```
### Recent Trades List
```javascript
client.trades(symbol: string, options: any = { limit: 500 })
options:{limit}
/**
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### Order Book
```javascript
client.depth(symbol: string, options: any = { limit: 100 })
options:{limit}
/**
* limit :
* Number of returned data
* Default 100;
* max 5000;
* Valid:[5, 10, 20, 50, 100, 500, 1000, 5000]
*
*/
```
### Old Trade Lookup
```javascript
client.historicalTrades(symbol: string, options: any = { limit: 500 })
options:{limit, fromId}
/**
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
* fromId:
* Trade id to fetch from. Default gets most recent trades
*
*/
```
### Aggregate Trades List
```javascript
client.aggTrades(symbol: string, options: any = { limit: 500 })
options:{fromId, startTime, endTime, limit}
/**
*
* fromId :
* id to get aggregate trades from INCLUSIVE
*
* startTime:
* start at
*
* endTime:
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### kline Data
```javascript
client.klines(symbol: string, interval: string, options: any = { limit: 500 })
options:{ startTime, endTime, limit}
/**
*
* interval :
* m :minute;
* h :Hour;
* d :day;
* w :week;
* M :month
* example : "1m"
*
* startTime :
* start at
*
* endTime :
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### Current Average Price
```javascript
client.avgPrice(symbol: string)
```
### 24hr Ticker Price Change Statistics
```javascript
client.ticker24hr(symbol?: string)
```
### Symbol Price Ticker
```javascript
client.tickerPrice(symbol?: string)
```
### Symbol Order Book Ticker
```javascript
client.bookTicker(symbol?: string)
```
## Trade
### Test New Order
```javascript
client.newOrderTest(symbol: string, side: string, orderType: string, options: any = {})
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
/**
*
* side:
* Order side
* ENUM:
* BUY
* SELL
*
* orderType:
* Order type
* ENUM:
* LIMIT
* MARKET
* STOP_LOSS
* STOP_LOSS_LIMIT
* TAKE_PROFIT
* TAKE_PROFIT_LIMIT
* LIMIT_MAKER
*
* timeInForce :
* How long an order will be active before expiration.
* GTC: Active unless the order is canceled
* IOC: Order will try to fill the order as much as it can before the order expires
* FOK: Active unless the full order cannot be filled upon execution.
*
* quantity :
* target quantity
*
* quoteOrderQty :
* Specify the total spent or received
*
* price :
* target price
*
* newClientOrderId :
* A unique id among open orders. Automatically generated if not sent
*
* stopPrice :
* sed with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
*
* icebergQty :
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
*
* newOrderRespType :
* Set the response JSON. ACK, RESULT, or FULL;
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
*
* recvWindow :
* Delay accept time
* The value cannot be greater than 60000
* defaults: 5000
*
*/
```
### New Order
```javascript
client.newOrder(symbol: string, side: string, orderType: string, options: any = {})
options:{ timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
/**
*
* side:
* Order side
* ENUM:
* BUY
* SELL
*
* orderType:
* Order type
* ENUM:
* LIMIT
* MARKET
* STOP_LOSS
* STOP_LOSS_LIMIT
* TAKE_PROFIT
* TAKE_PROFIT_LIMIT
* LIMIT_MAKER
*
* timeInForce :
* How long an order will be active before expiration.
* GTC: Active unless the order is canceled
* IOC: Order will try to fill the order as much as it can before the order expires
* FOK: Active unless the full order cannot be filled upon execution.
*
* quantity :
* target quantity
*
* quoteOrderQty :
* Specify the total spent or received
*
* price :
* target price
*
* newClientOrderId :
* A unique id among open orders. Automatically generated if not sent
*
* stopPrice :
* sed with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
*
* icebergQty :
* Used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
*
* newOrderRespType :
* Set the response JSON. ACK, RESULT, or FULL;
* MARKET and LIMIT order types default to FULL, all other orders default to ACK
*
* recvWindow :
* Delay accept time
* The value cannot be greater than 60000
* defaults: 5000
*
*/
```
### cancel-order
```javascript
client.cancelOrder(symbol: string, options:any = {})
options:{ orderId, origClientOrderId, newClientOrderId}
/**
*
* Either orderId or origClientOrderId must be sent
*
* orderId:
* target orderId
*
* origClientOrderId:
* target origClientOrderId
*
* newClientOrderId:
* Used to uniquely identify this cancel. Automatically generated by default.
*
*/
```
### Cancel all Open Orders on a Symbol
```javascript
client.cancelOpenOrders(symbol: string)
```
### Query Order
```javascript
client.queryOrder(symbol: string, options:any = {})
options:{ orderId, origClientOrderId}
/**
*
* Either orderId or origClientOrderId must be sent
*
* orderId:
* target orderId
*
* origClientOrderId:
* target origClientOrderId
*
*/
```
### Current Open Orders
```javascript
client.openOrders(symbol: string)
```
### All Orders
```javascript
client.allOrders(symbol: string, options: any = { limit: 500 })
options:{ orderId, startTime, endTime, limit}
/**
*
* orderId:
* target orderId
*
* startTime:
* start at
*
* endTime:
* end at
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```
### Account Information
```javascript
client.accountInfo()
```
### Account Trade List
```javascript
client.accountTradeList(symbol: string, options:any = { limit: 500 })
options:{ orderId, startTime, endTime, fromId, limit}
/**
*
* orderId:
* target orderId
*
* startTime:
* start at
*
* endTime:
* end at
*
* fromId:
* TradeId to fetch from. Default gets most recent trades
*
* limit :
* Number of returned data
* Default 500;
* max 1000;
*
*/
```

View File

@ -1,182 +1,91 @@
# Crypto Trading Bot with Reinforcement Learning # Crypto Trading Bot with MEXC API Integration
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition. This is an AI-powered cryptocurrency trading bot that can run in both simulation (demo) mode and live trading mode using the MEXC exchange API.
## Features ## Features
- Deep Q-Learning with experience replay - Deep Reinforcement Learning agent for trading decisions
- LSTM layers for sequential data processing - Technical indicators and price prediction
- Multi-head attention mechanism - Live trading integration with MEXC exchange via mexc-api
- Dueling DQN architecture - Demo mode for testing without real trades
- Real-time trading capabilities - Real-time data streaming via websockets
- TensorBoard integration for monitoring - Performance tracking and visualization
- Comprehensive technical indicators
- Demo and live trading modes
- Automatic model checkpointing
## Prerequisites ## Setup
- Python 3.8+ 1. Clone the repository
- MEXC Exchange API credentials 2. Install dependencies:
- GPU recommended but not required ```
pip install -r requirements.txt
```
3. Create a `.env` file in the root directory with your MEXC API keys:
```
MEXC_API_KEY=your_api_key_here
MEXC_SECRET_KEY=your_secret_key_here
```
## Installation
1. Clone the repository:
```bash
git clone https://github.com/yourusername/crypto-trading-bot.git
cd crypto-trading-bot
```
2. Create a virtual environment:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
3. Install dependencies:
```bash
pip install -r requirements.txt
```
4. Create a `.env` file in the project root with your MEXC API credentials:
```bash
MEXC_API_KEY=your_api_key
MEXC_API_SECRET=your_api_secret
```
## Usage ## Usage
The bot can be run in three modes: The bot can be run in three different modes:
### Training Mode ### Training Mode
```bash Train the agent on historical data:
python main.py --mode train --episodes 1000
```
python main.py --mode train --episodes 100
``` ```
### Evaluation Mode ### Evaluation Mode
```bash Evaluate the trained agent on historical data:
python main.py --mode eval --episodes 10
```
python main.py --mode evaluate
``` ```
### Live Trading Mode ### Live Trading Mode
```bash Run the bot in live trading mode:
# Demo mode (simulated trading with real market data)
python main.py --mode live --demo
# Real trading (actual trades on MEXC) ```
python main.py --mode live python main.py --mode live
``` ```
Demo mode simulates trading using real-time market data but does not execute actual trades. It still: To run in demo mode (no real trades):
- Logs all trading decisions and performance metrics
- Updates the model based on market data (if in training mode)
- Displays real-time analytics and position information
- Calculates theoretical profits/losses
- Saves performance data to TensorBoard
This makes it perfect for testing strategies without financial risk. ```
python main.py --mode live --demo
```
## Live Trading Implementation
The bot uses the mexc-api package to execute trades on the MEXC exchange. The implementation includes:
- Market order execution for opening and closing positions
- Stop loss and take profit orders
- Real-time balance updates
- Trade history tracking
## Configuration ## Configuration
Key parameters can be adjusted in `main.py`: You can adjust the following parameters in `main.py`:
- `INITIAL_BALANCE`: Starting balance for training/demo - `INITIAL_BALANCE`: Starting balance for simulation
- `MAX_LEVERAGE`: Maximum leverage for trades - `MAX_LEVERAGE`: Leverage to use for trading
- `STOP_LOSS_PERCENT`: Stop loss percentage - `STOP_LOSS_PERCENT`: Default stop loss percentage
- `TAKE_PROFIT_PERCENT`: Take profit percentage - `TAKE_PROFIT_PERCENT`: Default take profit percentage
- `BATCH_SIZE`: Training batch size
- `LEARNING_RATE`: Model learning rate
- `STATE_SIZE`: Size of the state representation
## Model Architecture ## Architecture
The DQN model includes:
- Input layer with technical indicators
- LSTM layers for temporal pattern recognition
- Multi-head attention mechanism
- Dueling architecture for better Q-value estimation
- Batch normalization for stable training
## Monitoring
Training progress can be monitored using TensorBoard:
Training progress is logged to TensorBoard:
```bash
tensorboard --logdir=logs
```
This will show:
- Training rewards
- Account balance
- Win rate
- Loss metrics
## Trading Strategy
The bot makes decisions based on:
- Price action
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Historical patterns through LSTM
- Risk management with stop-loss and take-profit
## Safety Features
- Demo mode for safe testing
- Automatic stop-loss
- Position size limits
- Error handling for API calls
- Logging of all actions
## Directory Structure
├── main.py # Main bot implementation
├── requirements.txt # Project dependencies
├── .env # API credentials
├── models/ # Saved model checkpoints
├── runs/ # TensorBoard logs
└── trading_bot.log # Activity logs
- `main.py`: Main entry point and trading logic
- `mexc_trading.py`: MEXC API integration for live trading using mexc-api
- `models/`: Directory for saved model weights
## Warning ## Warning
Cryptocurrency trading carries significant risks. This bot is for educational purposes and should not be used with real money without thorough testing and understanding of the risks involved. Trading cryptocurrencies involves significant risk. This bot is provided for educational purposes only. Use at your own risk.
## License ## License
[MIT License](LICENSE) MIT
The main changes I made:
Fixed code block formatting by adding proper language identifiers
Added missing closing code blocks
Properly formatted directory structure
Added complete sections that were cut off in the original
Ensured consistent formatting throughout the document
Added proper bash syntax highlighting for command examples
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
# Edits/improvements
Fixes the shape mismatch by ensuring the state vector is exactly STATE_SIZE elements
Adds robust error handling in the model's forward pass to handle mismatched inputs
Adds a transformer encoder for more sophisticated pattern recognition
Provides an expand_model method to increase model capacity while preserving learned weights
Adds detailed logging about model size and shape mismatches
The model now has:
Configurable hidden layer sizes
Transformer layers for complex pattern recognition
LSTM layers for temporal patterns
Attention mechanisms for focusing on important features
Dueling architecture for better Q-value estimation
With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need to implement a more distributed architecture with multiple GPUs, which would require significant changes to the training loop.

View File

@ -1,9 +1,11 @@
numpy>=1.21.0 numpy>=1.20.0
pandas>=1.3.0 pandas>=1.3.0
matplotlib>=3.4.0 matplotlib>=3.4.0
torch>=1.9.0 torch>=1.9.0
python-dotenv>=0.19.0 scikit-learn>=0.24.0
ccxt>=2.0.0 ccxt>=2.0.0
python-dotenv>=0.19.0
websockets>=10.0 websockets>=10.0
tensorboard>=2.6.0 tensorboard>=2.7.0
scikit-learn mexc-api>=1.0.0
asyncio>=3.4.3

View File

@ -0,0 +1,316 @@
import argparse
import os
import sys
import asyncio
import torch
from enhanced_training import enhanced_train_agent
from exchange_simulator import ExchangeSimulator
# Fix for Windows asyncio
if sys.platform == 'win32':
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'],
help='Mode to run the trading bot in')
parser.add_argument('--episodes', type=int, default=100,
help='Number of episodes to train for')
parser.add_argument('--start-episode', type=int, default=0,
help='Episode to start from for continuous training')
parser.add_argument('--device', type=str, default='auto',
help='Device to train on (auto, cuda, cpu)')
parser.add_argument('--timeframes', type=str, default='1m,15m,1h',
help='Comma-separated list of timeframes to use')
parser.add_argument('--refresh-data', action='store_true',
help='Refresh data before training')
parser.add_argument('--verbose', action='store_true',
help='Enable verbose logging')
args = parser.parse_args()
# Set device
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(args.device)
print(f"Using device: {device}")
# Parse timeframes
timeframes = args.timeframes.split(',')
print(f"Using timeframes: {timeframes}")
# Initialize exchange simulator
exchange = ExchangeSimulator()
# Run in specified mode
if args.mode == 'train':
# Train from scratch
print(f"Training for {args.episodes} episodes...")
enhanced_train_agent(
exchange=exchange,
num_episodes=args.episodes,
continuous=False,
start_episode=0,
verbose=args.verbose
)
elif args.mode == 'continuous':
# Continue training from checkpoint
print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...")
enhanced_train_agent(
exchange=exchange,
num_episodes=args.episodes,
continuous=True,
start_episode=args.start_episode,
verbose=args.verbose
)
elif args.mode == 'evaluate':
# Evaluate the model
print("Evaluating model...")
evaluate_model(exchange, device)
elif args.mode == 'live' or args.mode == 'demo':
# Run in live or demo mode
is_demo = args.mode == 'demo'
print(f"Running in {'demo' if is_demo else 'live'} mode...")
run_live(exchange, device, is_demo=is_demo)
print("Done!")
def evaluate_model(exchange, device):
"""
Evaluate the trained model
Args:
exchange: Exchange simulator
device: Device to run on
"""
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
import torch
import numpy as np
# Load the best model
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
if not os.path.exists(model_path):
model_path = 'models/enhanced_trading_agent_latest.pt'
if not os.path.exists(model_path):
print("No model found to evaluate!")
return
print(f"Loading model from {model_path}")
checkpoint = torch.load(model_path, map_location=device)
# Initialize models
state_dim = 100
action_dim = 3
timeframes = ['1m', '15m', '1h']
price_model = EnhancedPricePredictionModel(
input_dim=2,
hidden_dim=256,
num_layers=3,
output_dim=5,
num_timeframes=len(timeframes)
).to(device)
dqn_model = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(device)
# Load model weights
price_model.load_state_dict(checkpoint['price_model_state_dict'])
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
# Set models to evaluation mode
price_model.eval()
dqn_model.eval()
# Run evaluation
num_steps = 1000
total_reward = 0
trades = []
# Initialize state
from enhanced_training import initialize_state, step_environment
state = initialize_state(exchange, timeframes)
for step in range(num_steps):
# Select action
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
q_values, _, _ = dqn_model(state_tensor)
action = q_values.argmax().item()
# Execute action
next_state, reward, done, trade_info = step_environment(
exchange, state, action, price_model, timeframes, device
)
# Update state and accumulate reward
state = next_state
total_reward += reward
# Track trade
if trade_info is not None:
trades.append(trade_info)
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}")
# Update exchange (simulate time passing)
if step % 10 == 0:
exchange.update()
if done:
break
# Calculate metrics
avg_reward = total_reward / num_steps
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
wins = sum(1 for trade in trades if trade['pnl'] > 0)
losses = sum(1 for trade in trades if trade['pnl'] < 0)
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
print("\nEvaluation Results:")
print(f"Average Reward: {avg_reward:.2f}")
print(f"Total PnL: ${total_pnl:.2f}")
print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})")
def run_live(exchange, device, is_demo=True):
"""
Run the trading bot in live or demo mode
Args:
exchange: Exchange simulator or real exchange
device: Device to run on
is_demo: Whether to run in demo mode (no real trades)
"""
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
import torch
import time
# Load the best model
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
if not os.path.exists(model_path):
model_path = 'models/enhanced_trading_agent_latest.pt'
if not os.path.exists(model_path):
print("No model found to run in live mode!")
return
print(f"Loading model from {model_path}")
checkpoint = torch.load(model_path, map_location=device)
# Initialize models
state_dim = 100
action_dim = 3
timeframes = ['1m', '15m', '1h']
price_model = EnhancedPricePredictionModel(
input_dim=2,
hidden_dim=256,
num_layers=3,
output_dim=5,
num_timeframes=len(timeframes)
).to(device)
dqn_model = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(device)
# Load model weights
price_model.load_state_dict(checkpoint['price_model_state_dict'])
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
# Set models to evaluation mode
price_model.eval()
dqn_model.eval()
# Run live trading
print(f"Running in {'demo' if is_demo else 'live'} mode...")
print("Press Ctrl+C to stop")
# Initialize state
from enhanced_training import initialize_state, step_environment
state = initialize_state(exchange, timeframes)
try:
while True:
# Select action
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
q_values, _, market_regime = dqn_model(state_tensor)
action = q_values.argmax().item()
# Get market regime prediction
regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0]
regime_names = ['Trending', 'Ranging', 'Volatile']
predicted_regime = regime_names[regime_probs.argmax()]
# Get current price
ticker = exchange.get_ticker()
current_price = ticker['last']
# Print state
print(f"\nCurrent price: ${current_price:.2f}")
print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)")
# Execute action
next_state, reward, _, trade_info = step_environment(
exchange, state, action, price_model, timeframes, device
)
# Print action
action_names = ['Hold', 'Buy', 'Sell']
print(f"Action: {action_names[action]}")
if trade_info is not None:
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}")
# Execute real trade if not in demo mode
if not is_demo:
if trade_info['action'] == 'buy':
order = exchange.create_order(
symbol="BTC/USDT",
type="market",
side="buy",
amount=trade_info['size'] / current_price
)
print(f"Executed buy order: {order}")
else: # sell
order = exchange.create_order(
symbol="BTC/USDT",
type="market",
side="sell",
amount=trade_info['size'] / current_price
)
print(f"Executed sell order: {order}")
# Update state
state = next_state
# Update exchange (simulate time passing)
exchange.update()
# Wait for next candle
time.sleep(5) # In a real implementation, this would wait for the next candle
except KeyboardInterrupt:
print("\nStopping live trading")
if __name__ == "__main__":
main()

18
crypto/gogo2/run_main.py Normal file
View File

@ -0,0 +1,18 @@
import asyncio
import logging
from enhanced_training import main
# Set up logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program terminated by user")
except Exception as e:
logger.error(f"Error running main: {e}")
import traceback
logger.error(traceback.format_exc())

185
crypto/gogo2/test_cache.py Normal file
View File

@ -0,0 +1,185 @@
import os
import sys
import json
import logging
import time
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger('cache_test')
# Import our cache implementation
from data_cache import ohlcv_cache
def generate_sample_data(num_candles=100):
"""Generate sample OHLCV data for testing"""
data = []
base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000) # Start from num_candles minutes ago
for i in range(num_candles):
timestamp = base_timestamp + (i * 60 * 1000) # Add i minutes
# Generate some random-ish but realistic looking price data
base_price = 1900.0 + (i * 0.5) # Slight uptrend
open_price = base_price - 0.5 + (i % 3)
close_price = base_price + 0.3 + ((i+1) % 4)
high_price = max(open_price, close_price) + 1.0 + (i % 2)
low_price = min(open_price, close_price) - 0.8 - (i % 2)
volume = 10.0 + (i % 10) * 2.0
data.append({
'timestamp': timestamp,
'open': open_price,
'high': high_price,
'low': low_price,
'close': close_price,
'volume': volume
})
return data
def test_cache_save_load():
"""Test saving and loading data from cache"""
logger.info("Testing cache save and load...")
# Generate sample data
data = generate_sample_data(100)
logger.info(f"Generated {len(data)} sample candles")
# Save to cache
symbol = "ETH/USDT"
timeframe = "1m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Load from cache
cached_data = ohlcv_cache.load(symbol, timeframe)
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
# Verify data integrity
if cached_data:
first_original = data[0]
first_cached = cached_data[0]
logger.info(f"First original candle: {first_original}")
logger.info(f"First cached candle: {first_cached}")
last_original = data[-1]
last_cached = cached_data[-1]
logger.info(f"Last original candle: {last_original}")
logger.info(f"Last cached candle: {last_cached}")
return success and cached_data and len(cached_data) == len(data)
def test_cache_append():
"""Test appending a new candle to cached data"""
logger.info("Testing cache append...")
# Generate sample data
data = generate_sample_data(100)
# Save to cache
symbol = "ETH/USDT"
timeframe = "5m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Generate a new candle
last_timestamp = data[-1]['timestamp']
new_timestamp = last_timestamp + (5 * 60 * 1000) # 5 minutes later
new_candle = {
'timestamp': new_timestamp,
'open': 1950.0,
'high': 1955.0,
'low': 1948.0,
'close': 1952.0,
'volume': 15.0
}
# Append to cache
success = ohlcv_cache.append(new_candle, symbol, timeframe)
logger.info(f"Appended to cache: {success}")
# Load from cache
cached_data = ohlcv_cache.load(symbol, timeframe)
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
# Verify the new candle was appended
if cached_data:
last_cached = cached_data[-1]
logger.info(f"New candle: {new_candle}")
logger.info(f"Last cached candle: {last_cached}")
return success and cached_data and len(cached_data) == len(data) + 1
def test_cache_dataframe():
"""Test converting cached data to a pandas DataFrame"""
logger.info("Testing cache to DataFrame conversion...")
# Generate sample data
data = generate_sample_data(100)
# Save to cache
symbol = "ETH/USDT"
timeframe = "15m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Convert to DataFrame
df = ohlcv_cache.to_dataframe(symbol, timeframe)
logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows")
# Display DataFrame info
if df is not None:
logger.info(f"DataFrame columns: {df.columns.tolist()}")
logger.info(f"DataFrame index: {df.index.name}")
logger.info(f"First row: {df.iloc[0].to_dict()}")
logger.info(f"Last row: {df.iloc[-1].to_dict()}")
return success and df is not None and len(df) == len(data)
def main():
"""Run all tests"""
logger.info("Starting cache tests...")
# Run tests
save_load_success = test_cache_save_load()
append_success = test_cache_append()
dataframe_success = test_cache_dataframe()
# Print results
logger.info("Test results:")
logger.info(f" Save/Load: {'PASS' if save_load_success else 'FAIL'}")
logger.info(f" Append: {'PASS' if append_success else 'FAIL'}")
logger.info(f" DataFrame: {'PASS' if dataframe_success else 'FAIL'}")
# Check cache directory contents
cache_dir = ohlcv_cache.cache_dir
logger.info(f"Cache directory: {cache_dir}")
if os.path.exists(cache_dir):
files = os.listdir(cache_dir)
logger.info(f"Cache files: {files}")
# Print file sizes
for file in files:
file_path = os.path.join(cache_dir, file)
size_kb = os.path.getsize(file_path) / 1024
logger.info(f" {file}: {size_kb:.2f} KB")
# Print first few lines of each file
with open(file_path, 'r') as f:
data = json.load(f)
logger.info(f" Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated')).strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f" Candles: {len(data.get('data', []))}")
return save_load_success and append_success and dataframe_success
if __name__ == "__main__":
main()

View File

@ -0,0 +1,23 @@
from mexc_api.spot import Spot
def test_mexc_api():
try:
# Initialize client with empty API keys for public endpoints
client = Spot("", "")
# Test server time endpoint
server_time = client.market.server_time()
print(f"Server time: {server_time}")
# Test ticker price endpoint
ticker = client.market.ticker_price("ETHUSDT")
print(f"ETH/USDT price: {ticker}")
print("MEXC API is working correctly!")
return True
except Exception as e:
print(f"Error testing MEXC API: {e}")
return False
if __name__ == "__main__":
test_mexc_api()

View File

@ -0,0 +1,34 @@
import asyncio
import os
from dotenv import load_dotenv
from mexc_trading import MexcTradingClient
# Load environment variables
load_dotenv()
async def test_trading_client():
"""Test the MexcTradingClient functionality"""
print("Initializing MexcTradingClient...")
client = MexcTradingClient(symbol="ETH/USDT")
# Test getting market price
print("Testing get_market_price...")
price = await client.get_market_price()
print(f"Current ETH/USDT price: {price}")
# If API keys are provided, test account balance
if os.getenv('MEXC_API_KEY') and os.getenv('MEXC_SECRET_KEY'):
print("Testing fetch_account_balance...")
balance = await client.fetch_account_balance()
print(f"Account balance: {balance} USDT")
print("Testing fetch_open_positions...")
positions = await client.fetch_open_positions()
print(f"Open positions: {positions}")
else:
print("No API keys provided. Skipping private endpoint tests.")
print("MexcTradingClient test completed!")
if __name__ == "__main__":
asyncio.run(test_trading_client())

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 60 KiB

After

Width:  |  Height:  |  Size: 307 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

View File

@ -0,0 +1,56 @@
import os
import argparse
import subprocess
import webbrowser
import time
from pathlib import Path
def main():
parser = argparse.ArgumentParser(description='Visualize TensorBoard logs')
parser.add_argument('--logdir', type=str, default='./logs', help='Directory containing TensorBoard logs')
parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
args = parser.parse_args()
log_dir = Path(args.logdir)
if not log_dir.exists():
print(f"Log directory {log_dir} does not exist. Creating it...")
log_dir.mkdir(parents=True, exist_ok=True)
# Check if TensorBoard is installed
try:
subprocess.run(['tensorboard', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except (subprocess.CalledProcessError, FileNotFoundError):
print("TensorBoard not found. Installing...")
subprocess.run(['pip', 'install', 'tensorboard'], check=True)
# Start TensorBoard server
print(f"Starting TensorBoard server on port {args.port}...")
tensorboard_process = subprocess.Popen(
['tensorboard', '--logdir', str(log_dir), '--port', str(args.port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# Wait for TensorBoard to start
time.sleep(3)
# Open browser
url = f"http://localhost:{args.port}"
print(f"Opening TensorBoard in browser: {url}")
webbrowser.open(url)
print("TensorBoard is running. Press Ctrl+C to stop.")
try:
# Keep the script running until interrupted
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Stopping TensorBoard server...")
tensorboard_process.terminate()
tensorboard_process.wait()
print("TensorBoard server stopped.")
if __name__ == "__main__":
main()