Compare commits: march-trad...master
3 commits: fdde2ff587, 690c61f230, 0148964409

.gitignore (vendored, 21 changes)
@@ -38,6 +38,21 @@ crypto/gogo2/trading_bot.log
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
*trading_agent_continuous_*.pt
*trading_agent_episode_*.pt
crypto/gogo2/models/trading_agent_continuous_*.pt
crypto/gogo2/visualizations/training_episode_*.png
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
crypto/gogo2/models/trading_agent_continuous_150.pt
crypto/gogo2/checkpoints/trading_agent_episode_0.pt
crypto/gogo2/checkpoints/trading_agent_episode_10.pt
crypto/gogo2/checkpoints/trading_agent_episode_20.pt
crypto/gogo2/checkpoints/trading_agent_episode_40.pt
crypto/gogo2/models/trading_agent_best_pnl.pt
crypto/gogo2/models/trading_agent_best_reward.pt
crypto/gogo2/models/trading_agent_best_winrate.pt
crypto/gogo2/models/trading_agent_continuous_0.pt
crypto/gogo2/models/trading_agent_continuous_50.pt
crypto/gogo2/models/trading_agent_continuous_100.pt
crypto/gogo2/models/trading_agent_continuous_150.pt
crypto/gogo2/models/trading_agent_emergency.pt
crypto/gogo2/models/trading_agent_episode_0.pt
crypto/gogo2/models/trading_agent_episode_10.pt
crypto/gogo2/models/trading_agent_episode_20.pt
crypto/gogo2/models/trading_agent_episode_30.pt
crypto/gogo2/models/trading_agent_final.pt
crypto/gogo2/.gitattributes (vendored, new file)
@@ -0,0 +1 @@
*.pt filter=lfs diff=lfs merge=lfs -text
crypto/gogo2/.gitignore (vendored, new file)
@@ -0,0 +1 @@
*.pt
crypto/gogo2/.vscode/launch.json (vendored, 12 changes)
@@ -5,8 +5,8 @@
         "name": "Train Bot",
         "type": "python",
         "request": "launch",
-        "program": "main_multiu_broken.py",
-        "args": ["--mode", "train", "--episodes", "10000"],
+        "program": "main.py",
+        "args": ["--mode", "train", "--episodes", "100"],
         "console": "integratedTerminal",
         "justMyCode": true
     },
@@ -14,7 +14,7 @@
         "name": "Evaluate Bot",
         "type": "python",
         "request": "launch",
-        "program": "main_multiu_broken.py",
+        "program": "main.py",
         "args": ["--mode", "eval", "--episodes", "10"],
         "console": "integratedTerminal",
         "justMyCode": true
@@ -23,7 +23,7 @@
         "name": "Live Trading (Demo)",
         "type": "python",
         "request": "launch",
-        "program": "main_multiu_broken.py",
+        "program": "main.py",
         "args": ["--mode", "live", "--demo"],
         "console": "integratedTerminal",
         "justMyCode": true
@@ -32,7 +32,7 @@
         "name": "Live Trading (Real)",
         "type": "python",
         "request": "launch",
-        "program": "main_multiu_broken.py",
+        "program": "main.py",
         "args": ["--mode", "live"],
         "console": "integratedTerminal",
         "justMyCode": true
@@ -41,7 +41,7 @@
         "name": "Continuous Training",
         "type": "python",
         "request": "launch",
-        "program": "main_multiu_broken.py",
+        "program": "main.py",
         "args": ["--mode", "continuous", "--refresh-data"],
         "console": "integratedTerminal",
         "justMyCode": true
(deleted file)
@@ -1,42 +0,0 @@
def step(self, action):
    """Take an action in the environment and return the next state, reward, and done flag"""
    # Store current price before taking action
    self.current_price = self.data[self.current_step]['close']

    # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
    if not self.demo and self.trading_client:
        # Execute real trades in live mode
        asyncio.create_task(self._execute_live_action(action))

    # Calculate reward (simulation still runs in parallel with live trading)
    reward, _ = self.calculate_reward(action)  # Unpack the tuple here

    # Check for stop loss / take profit hits
    self.check_sl_tp()

    # Move to next step
    self.current_step += 1
    done = self.current_step >= len(self.data) - 1

    # Get new state
    next_state = self.get_state()

    return next_state, reward, done


def calculate_reward(self, action):
    """Calculate the reward for the current action."""
    # ... (existing code)

    # Combine all reward components
    reward = pnl_reward + timing_reward + risk_reward + prediction_reward

    # Log components for analysis
    info = {
        'pnl_reward': pnl_reward,
        'timing_reward': timing_reward,
        'risk_reward': risk_reward,
        'prediction_reward': prediction_reward,
        'total_reward': reward
    }

    # NOTE: step() above unpacks "reward, _ = self.calculate_reward(action)", which
    # expects a tuple; returning only the reward here is inconsistent with that call site.
    return reward  # Return only the reward, not the info dictionary
(deleted file)
@@ -1,67 +0,0 @@
# Neural Network Architecture Analysis for Trading Bot

## Overview

This document provides a comprehensive analysis of the neural network architecture used in our trading bot system. The system consists of two main neural network components:

1. **Price Prediction Model** - Forecasts future price movements and extrema points
2. **DQN (Deep Q-Network)** - Makes trading decisions based on state representations

## 1. Price Prediction Model

### Architecture

```
PricePredictionModel(nn.Module)
├── Input Layer: [batch_size, seq_len, 2] (price, volume)
├── LSTM Layers: 2 stacked layers with hidden_size=128
├── Attention Mechanism: Self-attention with linear projections
├── Linear Layer 1: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 2: hidden_size → output_size (5 future prices)
└── Output: [batch_size, output_size]
```

### Data Flow

**Inputs:**
- `price_history`: Sequence of historical prices [batch_size, seq_len]
- `volume_history`: Sequence of historical volumes [batch_size, seq_len]

**Preprocessing** (a sketch follows below):
- Normalization using MinMaxScaler (0-1 range)
- Reshaping to [batch_size, seq_len, 2] (price and volume features)
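A minimal sketch of that preprocessing, for illustration only (assumes scikit-learn and NumPy, and 1-D per-sample price/volume arrays; the helper name is hypothetical):

```
import numpy as np
from sklearn.preprocessing import MinMaxScaler

def preprocess(price_history, volume_history):
    """Scale price and volume to [0, 1] and stack into [1, seq_len, 2]."""
    prices = MinMaxScaler().fit_transform(
        np.asarray(price_history, dtype=np.float32).reshape(-1, 1))
    volumes = MinMaxScaler().fit_transform(
        np.asarray(volume_history, dtype=np.float32).reshape(-1, 1))
    # Stack the two features and add a batch dimension for the LSTM
    return np.concatenate([prices, volumes], axis=1)[None, ...]
```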
**Forward Pass:**
1. Input data passes through LSTM layers
2. Self-attention mechanism applied to LSTM outputs
3. Linear layers process the attended features
4. Output represents predicted prices for the next 5 candles

**Outputs:**
- `predicted_prices`: Array of 5 future price predictions
- `predicted_extrema`: Binary indicators for potential price extrema points

## 2. DQN (Deep Q-Network)

### Architecture

```
DQN(nn.Module)
├── Input Layer: [batch_size, state_size]
├── Linear Layer 1: state_size → hidden_size (384)
├── ReLU Activation
├── LSTM Layers: 2 stacked layers with hidden_size=384
├── Multi-Head Attention: 4 attention heads
├── Linear Layer 2: hidden_size → hidden_size
├── ReLU Activation
├── Linear Layer 3: hidden_size → action_size (4)
└── Output: [batch_size, action_size] (Q-values for each action)
```

### Data Flow

**Inputs:**
- `state`: Current market state representation [batch_size, state_size]
- Price features (normalized prices, returns, volatility)
- Technical indicators (RSI, MACD, Stochastic,
@@ -1,17 +1,9 @@
pip install torch-tb-profiler

Ensure we use the GPU if available to train faster. During training we need an RL loop that looks at streaming data, plus retrospective backtesting/training on predictions. Since the start of training we have only been losing; implement a robust penalty and analysis when closing a losing trade, and improve the reward function.
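A minimal device-selection sketch for the GPU point above (standard PyTorch; `model` and `state` are placeholders for whatever network and tensor the training loop uses):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)   # move the network once
state = state.to(device)   # move each batch/tensor before the forward pass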
Do we call the trading API properly when in live mode? In live mode, also validate the current balance, and confirm trades were actually executed after placing the orders by checking the open orders (see the sketch below).

Our trading data chart (in TensorBoard) is not displayed properly: the candles seem to be drawn multiple times but shifted in time. We also do not correctly show the buy/sell events on the time axis, and we do not show the predicted price on the chart.
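A sketch of that balance/open-order validation using the ccxt client constructed in MexcTradingClient (fetch_balance, fetch_open_orders, and fetch_order are ccxt's unified API; the polling loop itself is my assumption, not the repo's code):

import time

def verify_order_filled(client, order_id, symbol, timeout_s=30):
    """Poll until the order leaves the open-orders list, then confirm its status."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        open_ids = {o['id'] for o in client.fetch_open_orders(symbol)}
        if order_id not in open_ids:
            # 'closed' is ccxt's status for a fully filled order
            return client.fetch_order(order_id, symbol).get('status') == 'closed'
        time.sleep(1)
    return False

def usdt_balance(client):
    """Validate the current free USDT balance before sizing a live trade."""
    return float(client.fetch_balance()['free'].get('USDT', 0.0))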
2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)

@@ -29,37 +21,4 @@ Backend tkagg is interactive backend. Turning interactive mode on.
2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%

----------------

Do we train the price prediction model by using old known candles and masking the latest so it has to guess the next one, then backpropagating against the already-known next candle, the way a transformer (GPT-2) would? Or do we use RL for that as well? A sketch of the supervised version follows.
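That supervised, GPT-style setup would look roughly like this (a sketch under the assumption that `model` maps a [batch, seq_len, 2] window to the next close and `candles` is a normalized [N, 2] tensor; not the repo's trainer):

import torch
import torch.nn as nn

def train_next_candle(model, candles, seq_len=30, epochs=10, lr=1e-3):
    """Teacher forcing: feed known candles, regress on the held-out next close."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    # Sliding windows of known candles, each paired with the next close as target
    x = torch.stack([candles[i:i + seq_len] for i in range(len(candles) - seq_len)])
    y = torch.stack([candles[i + seq_len, 0] for i in range(len(candles) - seq_len)])
    for _ in range(epochs):
        opt.zero_grad()
        pred = model(x).squeeze(-1)  # predicted next close per window
        loss = loss_fn(pred, y)
        loss.backward()
        opt.step()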
It seems the model is not learning much; we keep hovering around the starting balance even after some time training in continuous mode.

It seems we may need another NN model in the loop just to predict the extrema of the price. We may have to include a mechanism to calculate the price extrema retrospectively and use that to bootstrap/pre-train the model (see the sketch below).
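A sketch of that retrospective extrema labeling (scipy.signal.argrelextrema is a real SciPy function; the window size is an assumption):

import numpy as np
from scipy.signal import argrelextrema

def label_extrema(closes, order=5):
    """In hindsight, mark local highs (1), local lows (-1), everything else 0."""
    closes = np.asarray(closes)
    labels = np.zeros(len(closes), dtype=np.int8)
    labels[argrelextrema(closes, np.greater, order=order)[0]] = 1
    labels[argrelextrema(closes, np.less, order=order)[0]] = -1
    return labels  # usable as targets to bootstrap/pre-train an extrema head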
Why Performance Might Be Stagnating

Several factors could explain why the model isn't improving significantly during training:

1. Insufficient Model Capacity for Price Prediction: While the price prediction model has 267,663 parameters, financial time series prediction is extremely challenging. The market may have patterns that are too complex or too random for the current model to capture effectively.
2. Overfitting to Training Data: The model might be memorizing patterns in the training data that don't generalize to new market conditions.
3. Transformer-LSTM Redundancy in DQN: Your DQN model uses both a transformer and an LSTM, which might be redundant. Both are designed to capture sequential dependencies, and having both could lead to overfitting or training instability.
4. Imbalanced Parameter Distribution: 64.5% of your DQN parameters are in the transformer component, which might be excessive for the task.
5. Reward Function Issues: The reward function might not be properly aligned with profitable trading strategies, or it might be too sparse to provide meaningful learning signals.

Suggested Improvements

1. Enhance Price Prediction Training
2. Simplify the DQN Architecture: consider creating a more streamlined DQN model.
3. Improve the Reward Function: make sure your reward function provides meaningful signals for learning (a sketch follows this list).
4. Implement Curriculum Learning: start with simpler market conditions and gradually increase complexity.
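For point 3, a hedged sketch of a denser reward (component names mirror the calculate_reward fragment earlier in this diff; self.last_trade_pnl and self.prediction_correct are hypothetical attributes, and the weights are untuned assumptions):

def calculate_reward(self, action):
    """Dense reward: realized PnL plus shaped components, with a breakdown for logging."""
    pnl_reward = self.last_trade_pnl                       # hypothetical: PnL of a just-closed trade, else 0
    timing_reward = -0.01 if action == 0 else 0.0          # small cost for idling
    risk_reward = -0.1 * self.max_drawdown                 # penalize deep drawdowns
    prediction_reward = 0.05 if self.prediction_correct else -0.05  # hypothetical flag
    reward = pnl_reward + timing_reward + risk_reward + prediction_reward
    info = {'pnl_reward': pnl_reward, 'timing_reward': timing_reward,
            'risk_reward': risk_reward, 'prediction_reward': prediction_reward,
            'total_reward': reward}
    return reward, info  # return the tuple so step() can unpack it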
Conclusion

The issue appears to be a combination of model complexity, potential overfitting, and possibly insufficient learning signals from the reward function. By simplifying the DQN architecture (particularly reducing the transformer component), improving the price prediction training, and enhancing the reward function, you should see better learning progress.

Would you like me to implement any of these specific improvements to your codebase?

2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
(deleted file)
@@ -1,362 +0,0 @@
class MexcTradingClient:
    def __init__(self, api_key, secret_key, symbol="ETH/USDT", leverage=50):
        self.client = ccxt.mexc({
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True,
        })
        self.symbol = symbol
        self.leverage = leverage
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.trades = []

    def initialize_mexc_client(self, api_key, api_secret):
        """Initialize the MEXC API client"""
        try:
            from mexc_sdk import Spot
            self.mexc_client = Spot(api_key=api_key, api_secret=api_secret)
            # Test connection
            self.mexc_client.ping()
            logger.info("MEXC API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize MEXC API client: {e}")
            logger.error(traceback.format_exc())
            return False

    async def fetch_account_balance(self):
        """Fetch actual account balance from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, use simulated balance
            return self.balance

        try:
            account_info = self.mexc_client.accountInfo()
            if 'balances' in account_info:
                # Find USDT balance
                for asset in account_info['balances']:
                    if asset['asset'] == 'USDT':
                        return float(asset['free'])

            logger.warning("Could not find USDT balance, using current simulated balance")
            return self.balance
        except Exception as e:
            logger.error(f"Error fetching account balance: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated balance in case of API error
            return self.balance

    async def fetch_open_positions(self):
        """Fetch actual open positions from MEXC API"""
        if self.demo or not self.mexc_client:
            # In demo mode, return current simulated position
            return [{
                'symbol': 'ETH/USDT',
                'positionSide': 'LONG' if self.position == 'long' else 'SHORT' if self.position == 'short' else 'NONE',
                'positionAmt': self.position_size / self.current_price if self.position != 'flat' else 0,
                'entryPrice': self.entry_price,
                'unrealizedProfit': self.calculate_unrealized_pnl()
            }] if self.position != 'flat' else []

        try:
            # Fetch open positions
            positions = self.mexc_client.openOrders('ETH/USDT')
            return positions
        except Exception as e:
            logger.error(f"Error fetching open positions: {e}")
            logger.error(traceback.format_exc())
            # Fallback to simulated positions in case of API error
            return []

    def calculate_unrealized_pnl(self):
        """Calculate unrealized PnL for the current position"""
        if self.position == 'flat':
            return 0.0

        position_value = self.position_size / self.entry_price

        if self.position == 'long':
            pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
        else:  # short
            pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100

        # Apply leverage
        pnl_percent *= self.leverage

        return position_value * pnl_percent / 100

    async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
        """Open a new position using MEXC API in live mode, or simulate in demo mode"""
        if self.demo or not self.mexc_client:
            # In demo mode, simulate opening a position
            self.position = position_type
            self.position_size = size
            self.entry_price = entry_price
            self.entry_index = self.current_step
            self.stop_loss = stop_loss
            self.take_profit = take_profit

            logger.info(f"DEMO: Opened {position_type.upper()} position at {entry_price} | " +
                        f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
            return True

        try:
            # In live mode, place actual orders via API
            symbol = "ETHUSDT"  # Format required by MEXC
            side = "BUY" if position_type == 'long' else "SELL"

            # Calculate quantity based on size and price
            quantity = size / entry_price

            # Place main order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "leverage": self.leverage,
                    "newOrderRespType": "FULL"
                }
            )

            # Check if order executed
            if order_result.get('status') == 'FILLED':
                actual_entry_price = float(order_result.get('price', entry_price))

                # Place stop loss order
                sl_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="STOP_LOSS",
                    quantity=quantity,
                    options={
                        "stopPrice": stop_loss,
                        "newOrderRespType": "ACK"
                    }
                )

                # Place take profit order
                tp_order = self.mexc_client.newOrder(
                    symbol=symbol,
                    side="SELL" if position_type == 'long' else "BUY",
                    orderType="TAKE_PROFIT",
                    quantity=quantity,
                    options={
                        "stopPrice": take_profit,
                        "newOrderRespType": "ACK"
                    }
                )

                # Update local state
                self.position = position_type
                self.position_size = size
                self.entry_price = actual_entry_price
                self.entry_index = self.current_step
                self.stop_loss = stop_loss
                self.take_profit = take_profit

                # Track orders
                self.open_orders.extend([sl_order, tp_order])
                self.order_history.append(order_result)

                logger.info(f"LIVE: Opened {position_type.upper()} position at {actual_entry_price} | " +
                            f"Size: ${size:.2f} | SL: {stop_loss:.2f} | TP: {take_profit:.2f}")
                return True

            else:
                logger.error(f"Failed to execute order: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error opening position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def close_position(self, reason="manual_close"):
        """Close the current position using MEXC API in live mode, or simulate in demo mode"""
        if self.position == 'flat':
            logger.info("No position to close")
            return False

        if self.demo or not self.mexc_client:
            # In demo mode, simulate closing a position
            position_type = self.position
            entry_price = self.entry_price
            exit_price = self.current_price
            position_size = self.position_size

            # Calculate PnL
            if position_type == 'long':
                pnl_percent = (exit_price - entry_price) / entry_price * 100
            else:  # short
                pnl_percent = (entry_price - exit_price) / entry_price * 100

            # Apply leverage
            pnl_percent *= self.leverage

            # Calculate actual PnL
            pnl_dollar = position_size * pnl_percent / 100

            # Apply fees
            pnl_dollar -= self.calculate_fees(position_size)

            # Update balance
            self.balance += pnl_dollar
            self.total_pnl += pnl_dollar
            self.episode_pnl += pnl_dollar

            # Update max drawdown
            if self.balance > self.peak_balance:
                self.peak_balance = self.balance
            drawdown = (self.peak_balance - self.balance) / self.peak_balance
            self.max_drawdown = max(self.max_drawdown, drawdown)

            # Record trade
            self.trades.append({
                'type': position_type,
                'entry': entry_price,
                'exit': exit_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,
                'market_direction': self.get_market_direction(),
                'reason': reason,
                'leverage': self.leverage
            })

            # Update win/loss count
            if pnl_dollar > 0:
                self.win_count += 1
            else:
                self.loss_count += 1

            logger.info(f"DEMO: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

            # Reset position
            self.position = 'flat'
            self.entry_price = 0
            self.entry_index = 0
            self.position_size = 0
            self.stop_loss = 0
            self.take_profit = 0

            return True

        try:
            # In live mode, close position via API
            symbol = "ETHUSDT"
            position_info = await self.fetch_open_positions()

            if not position_info:
                logger.warning("No open positions found to close")
                self.position = 'flat'
                return False

            # First, cancel any existing stop loss/take profit orders
            try:
                self.mexc_client.cancelOpenOrders(symbol)
            except Exception as e:
                logger.warning(f"Error canceling open orders: {e}")

            # Close the position with a market order
            position_type = self.position
            side = "SELL" if position_type == 'long' else "BUY"
            quantity = self.position_size / self.current_price

            # Execute order
            order_result = self.mexc_client.newOrder(
                symbol=symbol,
                side=side,
                orderType="MARKET",
                quantity=quantity,
                options={
                    "newOrderRespType": "FULL"
                }
            )

            # Check if order executed
            if order_result.get('status') == 'FILLED':
                exit_price = float(order_result.get('price', self.current_price))
                entry_price = self.entry_price
                position_size = self.position_size

                # Calculate PnL
                if position_type == 'long':
                    pnl_percent = (exit_price - entry_price) / entry_price * 100
                else:  # short
                    pnl_percent = (entry_price - exit_price) / entry_price * 100

                # Apply leverage
                pnl_percent *= self.leverage

                # Calculate actual PnL
                pnl_dollar = position_size * pnl_percent / 100

                # Apply fees
                pnl_dollar -= self.calculate_fees(position_size)

                # Update balance from API
                self.balance = await self.fetch_account_balance()
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': position_type,
                    'entry': entry_price,
                    'exit': exit_price,
                    'entry_time': self.data[self.entry_index]['timestamp'],
                    'exit_time': self.data[self.current_step]['timestamp'],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': self.current_step - self.entry_index,
                    'market_direction': self.get_market_direction(),
                    'reason': reason,
                    'leverage': self.leverage,
                    'order_id': order_result.get('orderId')
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

                # Track order history
                self.order_history.append(order_result)

                logger.info(f"LIVE: Closed {position_type} at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.entry_index = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

                return True
            else:
                logger.error(f"Failed to close position: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error closing position: {e}")
            logger.error(traceback.format_exc())
            return False
@@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 0, "timestamp": "2025-03-12T00:23:19.125190"}
+{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}
(deleted file)
@@ -1,378 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class EnhancedPricePredictionModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=256, num_layers=3, output_dim=5, num_timeframes=3):
        super(EnhancedPricePredictionModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_timeframes = num_timeframes

        # Separate LSTM for each timeframe
        self.timeframe_lstms = nn.ModuleList([
            nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
            for _ in range(num_timeframes)
        ])

        # Cross-timeframe attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)

        # Self-attention for each timeframe
        self.self_attentions = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
            for _ in range(num_timeframes)
        ])

        # Timeframe fusion layer
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * num_timeframes, hidden_dim * 2),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

        # Fully connected layer for price prediction
        self.price_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim)
        )

        # Fully connected layer for extrema prediction (high and low points)
        self.extrema_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 10)  # 5 time steps, 2 classes (high/low) each
        )

        # Volume prediction layer
        self.volume_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x_list):
        # x_list is a list of tensors, one for each timeframe
        # Each x shape: (batch_size, seq_len, input_dim)

        # Process each timeframe with its own LSTM
        lstm_outputs = []
        for i, x in enumerate(x_list):
            lstm_out, _ = self.timeframe_lstms[i](x)  # lstm_out: (batch_size, seq_len, hidden_dim)
            lstm_outputs.append(lstm_out)

        # Apply self-attention to each timeframe
        attn_outputs = []
        for i, lstm_out in enumerate(lstm_outputs):
            attn_output, _ = self.self_attentions[i](lstm_out, lstm_out, lstm_out)
            attn_outputs.append(attn_output[:, -1, :])  # Use the last time step

        # Concatenate all timeframe representations
        combined = torch.cat(attn_outputs, dim=1)  # (batch_size, hidden_dim * num_timeframes)

        # Fuse timeframe information
        fused = self.fusion_layer(combined)  # (batch_size, hidden_dim)

        # Price prediction
        price_pred = self.price_fc(fused)

        # Extrema prediction
        extrema_logits = self.extrema_fc(fused)

        # Volume prediction
        volume_pred = self.volume_fc(fused)

        return price_pred, extrema_logits, volume_pred

class EnhancedDQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=512):
        super(EnhancedDQN, self).__init__()

        # Feature extraction layers with increased capacity
        self.feature_extraction = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        # Advantage stream with increased capacity
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, action_dim)
        )

        # Value stream with increased capacity
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Enhanced transformer for temporal dependencies
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=3)

        # LSTM for sequential decision making with increased capacity
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.1)

        # Final layers with increased capacity
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, action_dim)
        )

        # Market regime classification layer
        self.market_regime_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 3)  # 3 regimes: trending, ranging, volatile
        )

    def forward(self, state, hidden=None):
        # Extract features
        features = self.feature_extraction(state)
        features = features.unsqueeze(1)  # Add sequence dimension for transformer/LSTM

        # Transformer processing
        transformer_out = self.transformer(features)

        # LSTM processing
        lstm_out, lstm_hidden = self.lstm(transformer_out)

        # Dueling architecture
        advantage = self.advantage_stream(features.squeeze(1))
        value = self.value_stream(features.squeeze(1))

        # Combine transformer, LSTM and dueling outputs
        combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
        q_values = self.final_layers(combined)

        # Market regime classification
        market_regime = self.market_regime_classifier(transformer_out.squeeze(1))

        # Dueling Q-value computation
        dueling_q = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Final Q-values are a weighted combination of the dueling Q-values and the direct Q-values
        # This allows the model to use either approach depending on the situation
        q_values = 0.5 * dueling_q + 0.5 * q_values

        return q_values, lstm_hidden, market_regime

def count_parameters(model):
    total_params = 0
    layer_params = {}

    for name, param in model.named_parameters():
        if param.requires_grad:
            param_count = param.numel()
            total_params += param_count
            layer_params[name] = (param_count, param.shape)

    return total_params, layer_params

def main():
    # Initialize the original models for comparison
    original_price_model = PricePredictionModel()
    original_price_total_params, _ = count_parameters(original_price_model)

    state_dim = 50
    action_dim = 3
    original_dqn_model = DQN(state_dim=state_dim, action_dim=action_dim)
    original_dqn_total_params, _ = count_parameters(original_dqn_model)

    # Initialize the enhanced models
    enhanced_price_model = EnhancedPricePredictionModel(num_timeframes=3)
    enhanced_price_total_params, enhanced_price_layer_params = count_parameters(enhanced_price_model)

    # Increase state dimension to accommodate multiple timeframes
    enhanced_state_dim = 100  # Increased from 50 to accommodate more features
    enhanced_dqn_model = EnhancedDQN(state_dim=enhanced_state_dim, action_dim=action_dim)
    enhanced_dqn_total_params, enhanced_dqn_layer_params = count_parameters(enhanced_dqn_model)

    # Print comparison
    print("=== MODEL SIZE COMPARISON ===")
    print(f"Original Price Prediction Model: {original_price_total_params:,} parameters")
    print(f"Enhanced Price Prediction Model: {enhanced_price_total_params:,} parameters")
    print(f"Growth Factor: {enhanced_price_total_params / original_price_total_params:.2f}x\n")

    print(f"Original DQN Model: {original_dqn_total_params:,} parameters")
    print(f"Enhanced DQN Model: {enhanced_dqn_total_params:,} parameters")
    print(f"Growth Factor: {enhanced_dqn_total_params / original_dqn_total_params:.2f}x\n")

    print(f"Total Original Models: {original_price_total_params + original_dqn_total_params:,} parameters")
    print(f"Total Enhanced Models: {enhanced_price_total_params + enhanced_dqn_total_params:,} parameters")
    print(f"Overall Growth Factor: {(enhanced_price_total_params + enhanced_dqn_total_params) / (original_price_total_params + original_dqn_total_params):.2f}x\n")

    # Print VRAM usage estimate (rough approximation)
    bytes_per_param = 4  # 4 bytes for float32
    original_vram_mb = (original_price_total_params + original_dqn_total_params) * bytes_per_param / (1024 * 1024)
    enhanced_vram_mb = (enhanced_price_total_params + enhanced_dqn_total_params) * bytes_per_param / (1024 * 1024)

    print("=== ESTIMATED VRAM USAGE ===")
    print(f"Original Models: {original_vram_mb:.2f} MB")
    print(f"Enhanced Models: {enhanced_vram_mb:.2f} MB")
    print("Available VRAM: 8,192 MB (8 GB)")
    print(f"VRAM Utilization: {enhanced_vram_mb / 8192 * 100:.2f}%\n")

    # Print detailed breakdown of enhanced models
    print("=== ENHANCED PRICE PREDICTION MODEL BREAKDOWN ===")

    # Group parameters by component
    timeframe_lstm_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "timeframe_lstms" in name)
    attention_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "attention" in name)
    fusion_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if "fusion" in name)
    output_params = sum(count for name, (count, _) in enhanced_price_layer_params.items() if any(x in name for x in ["price_fc", "extrema_fc", "volume_fc"]))

    print(f"Timeframe LSTMs: {timeframe_lstm_params:,} parameters ({timeframe_lstm_params/enhanced_price_total_params*100:.1f}%)")
    print(f"Attention Mechanisms: {attention_params:,} parameters ({attention_params/enhanced_price_total_params*100:.1f}%)")
    print(f"Fusion Layer: {fusion_params:,} parameters ({fusion_params/enhanced_price_total_params*100:.1f}%)")
    print(f"Output Layers: {output_params:,} parameters ({output_params/enhanced_price_total_params*100:.1f}%)\n")

    print("=== ENHANCED DQN MODEL BREAKDOWN ===")

    # Group parameters by component
    feature_extraction_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "feature_extraction" in name)
    advantage_value_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "advantage_stream" in name or "value_stream" in name)
    transformer_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "transformer" in name)
    lstm_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "lstm" in name and "transformer" not in name)
    final_layers_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "final_layers" in name)
    market_regime_params = sum(count for name, (count, _) in enhanced_dqn_layer_params.items() if "market_regime" in name)

    print(f"Feature Extraction: {feature_extraction_params:,} parameters ({feature_extraction_params/enhanced_dqn_total_params*100:.1f}%)")
    print(f"Advantage & Value Streams: {advantage_value_params:,} parameters ({advantage_value_params/enhanced_dqn_total_params*100:.1f}%)")
    print(f"Transformer: {transformer_params:,} parameters ({transformer_params/enhanced_dqn_total_params*100:.1f}%)")
    print(f"LSTM: {lstm_params:,} parameters ({lstm_params/enhanced_dqn_total_params*100:.1f}%)")
    print(f"Final Layers: {final_layers_params:,} parameters ({final_layers_params/enhanced_dqn_total_params*100:.1f}%)")
    print(f"Market Regime Classifier: {market_regime_params:,} parameters ({market_regime_params/enhanced_dqn_total_params*100:.1f}%)")

# Keep the original models for comparison
class PricePredictionModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=128, num_layers=2, output_dim=5):
        super(PricePredictionModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        # Self-attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)

        # Fully connected layer for price prediction
        self.price_fc = nn.Linear(hidden_dim, output_dim)

        # Fully connected layer for extrema prediction (high and low points)
        self.extrema_fc = nn.Linear(hidden_dim, 10)  # 5 time steps, 2 classes (high/low) each

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_dim)

        # LSTM forward pass
        lstm_out, _ = self.lstm(x)  # lstm_out: (batch_size, seq_len, hidden_dim)

        # Self-attention
        attn_output, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Price prediction
        price_pred = self.price_fc(attn_output[:, -1, :])  # Use the last time step

        # Extrema prediction
        extrema_logits = self.extrema_fc(attn_output[:, -1, :])

        return price_pred, extrema_logits

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(DQN, self).__init__()

        # Feature extraction layers
        self.feature_extraction = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Transformer for temporal dependencies
        encoder_layers = TransformerEncoderLayer(d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim * 4, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)

        # LSTM for sequential decision making
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)

        # Final layers
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state, hidden=None):
        # Extract features
        features = self.feature_extraction(state)
        features = features.unsqueeze(1)  # Add sequence dimension for transformer/LSTM

        # Transformer processing
        transformer_out = self.transformer(features)

        # LSTM processing
        lstm_out, lstm_hidden = self.lstm(transformer_out)

        # Dueling architecture
        advantage = self.advantage_stream(features.squeeze(1))
        value = self.value_stream(features.squeeze(1))

        # Combine transformer, LSTM and dueling outputs
        combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
        q_values = self.final_layers(combined)

        # Dueling Q-value computation
        # NOTE: this overwrites the final_layers output computed above, so the
        # combined transformer/LSTM features never reach the returned Q-values.
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        return q_values, lstm_hidden

if __name__ == "__main__":
    main()
(deleted file)
@@ -1,319 +0,0 @@
import os
import json
import time
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import logging

# Set up logging
logger = logging.getLogger('trading_bot')

class OHLCVCache:
    """
    A simple cache for OHLCV data from exchanges.
    Stores data in a structured format and provides backup when the exchange is unavailable.
    """
    def __init__(self, cache_dir="cache", max_age_hours=24):
        """
        Initialize the OHLCV cache.

        Args:
            cache_dir: Directory to store cache files
            max_age_hours: Maximum age of cached data in hours before considered stale
        """
        self.cache_dir = cache_dir
        self.max_age_seconds = max_age_hours * 3600

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        # In-memory cache for faster access
        self.memory_cache = {}

    def _get_cache_filename(self, symbol, timeframe):
        """Generate a standardized filename for the cache file"""
        # Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
        safe_symbol = symbol.replace('/', '_')
        return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")

    def save(self, data, symbol, timeframe):
        """
        Save OHLCV data to cache.

        Args:
            data: List of dictionaries containing OHLCV data
            symbol: Trading pair symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
        """
        if not data:
            logger.warning(f"No data to cache for {symbol} ({timeframe})")
            return False

        try:
            # Convert data to a serializable format
            serializable_data = []
            for candle in data:
                serializable_data.append({
                    'timestamp': candle['timestamp'],
                    'open': float(candle['open']),
                    'high': float(candle['high']),
                    'low': float(candle['low']),
                    'close': float(candle['close']),
                    'volume': float(candle['volume'])
                })

            # Create cache entry with metadata
            cache_entry = {
                'symbol': symbol,
                'timeframe': timeframe,
                'last_updated': int(time.time()),
                'data': serializable_data
            }

            # Save to file
            filename = self._get_cache_filename(symbol, timeframe)
            with open(filename, 'w') as f:
                json.dump(cache_entry, f)

            # Update in-memory cache
            cache_key = f"{symbol}_{timeframe}"
            self.memory_cache[cache_key] = cache_entry

            logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
            return True

        except Exception as e:
            logger.error(f"Error saving data to cache: {e}")
            return False

    def load(self, symbol, timeframe, max_age_override=None):
        """
        Load OHLCV data from cache.

        Args:
            symbol: Trading pair symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
            max_age_override: Override the default max age (in seconds)

        Returns:
            List of dictionaries containing OHLCV data, or None if cache is missing or stale
        """
        cache_key = f"{symbol}_{timeframe}"
        max_age = max_age_override if max_age_override is not None else self.max_age_seconds

        try:
            # Check in-memory cache first
            if cache_key in self.memory_cache:
                cache_entry = self.memory_cache[cache_key]

                # Check if cache is fresh
                cache_age = int(time.time()) - cache_entry['last_updated']
                if cache_age <= max_age:
                    logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes")
                    return cache_entry['data']

            # Check file cache
            filename = self._get_cache_filename(symbol, timeframe)
            if not os.path.exists(filename):
                logger.info(f"No cache file found for {symbol} ({timeframe})")
                return None

            # Load cache file
            with open(filename, 'r') as f:
                cache_entry = json.load(f)

            # Check if cache is fresh
            cache_age = int(time.time()) - cache_entry['last_updated']
            if cache_age > max_age:
                logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)")
                return None

            # Update in-memory cache
            self.memory_cache[cache_key] = cache_entry

            logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})")
            return cache_entry['data']

        except Exception as e:
            logger.error(f"Error loading data from cache: {e}")
            return None

    def append(self, new_candle, symbol, timeframe):
        """
        Append a new candle to the cached data.

        Args:
            new_candle: Dictionary containing a single OHLCV candle
            symbol: Trading pair symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')

        Returns:
            Boolean indicating success
        """
        try:
            # Load existing data
            data = self.load(symbol, timeframe, max_age_override=float('inf'))  # Ignore age for append

            if data is None:
                data = []

            # Check if the candle already exists (same timestamp)
            for i, candle in enumerate(data):
                if candle['timestamp'] == new_candle['timestamp']:
                    # Update existing candle
                    data[i] = {
                        'timestamp': new_candle['timestamp'],
                        'open': float(new_candle['open']),
                        'high': float(new_candle['high']),
                        'low': float(new_candle['low']),
                        'close': float(new_candle['close']),
                        'volume': float(new_candle['volume'])
                    }
                    # Save updated data
                    return self.save(data, symbol, timeframe)

            # Append new candle
            data.append({
                'timestamp': new_candle['timestamp'],
                'open': float(new_candle['open']),
                'high': float(new_candle['high']),
                'low': float(new_candle['low']),
                'close': float(new_candle['close']),
                'volume': float(new_candle['volume'])
            })

            # Save updated data
            return self.save(data, symbol, timeframe)

        except Exception as e:
            logger.error(f"Error appending candle to cache: {e}")
            return False

    def get_latest_timestamp(self, symbol, timeframe):
        """
        Get the timestamp of the most recent candle in the cache.

        Args:
            symbol: Trading pair symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')

        Returns:
            Timestamp (milliseconds) of the most recent candle, or None if cache is empty
        """
        data = self.load(symbol, timeframe, max_age_override=float('inf'))  # Ignore age for this check

        if not data:
            return None

        # Find the most recent timestamp
        latest_timestamp = max(candle['timestamp'] for candle in data)
        return latest_timestamp

    def clear(self, symbol=None, timeframe=None):
        """
        Clear cache for a specific symbol and timeframe, or all cache if not specified.

        Args:
            symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes

        Returns:
            Number of cache files deleted
        """
        count = 0

        try:
            if symbol and timeframe:
                # Clear specific cache
                filename = self._get_cache_filename(symbol, timeframe)
                if os.path.exists(filename):
                    os.remove(filename)
                    count = 1

                # Clear from memory cache
                cache_key = f"{symbol}_{timeframe}"
                if cache_key in self.memory_cache:
                    del self.memory_cache[cache_key]

            else:
                # Clear all matching caches
                for filename in os.listdir(self.cache_dir):
                    file_path = os.path.join(self.cache_dir, filename)

                    # Skip directories
                    if not os.path.isfile(file_path):
                        continue

                    # Check if file matches the filter
                    should_delete = True

                    if symbol:
                        safe_symbol = symbol.replace('/', '_')
                        if not filename.startswith(f"{safe_symbol}_"):
                            should_delete = False

                    if timeframe:
                        if not filename.endswith(f"_{timeframe}.json"):
                            should_delete = False

                    # Delete file if it matches the filter
                    if should_delete:
                        os.remove(file_path)
                        count += 1

                # Clear memory cache
                keys_to_delete = []
                for cache_key in self.memory_cache:
                    should_delete = True

                    if symbol:
                        if not cache_key.startswith(f"{symbol}_"):
                            should_delete = False

                    if timeframe:
                        if not cache_key.endswith(f"_{timeframe}"):
                            should_delete = False

                    if should_delete:
                        keys_to_delete.append(cache_key)

                for key in keys_to_delete:
                    del self.memory_cache[key]

            logger.info(f"Cleared {count} cache files")
            return count

        except Exception as e:
            logger.error(f"Error clearing cache: {e}")
            return 0

    def to_dataframe(self, symbol, timeframe):
        """
        Convert cached OHLCV data to a pandas DataFrame.

        Args:
            symbol: Trading pair symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')

        Returns:
            pandas DataFrame with OHLCV data, or None if cache is missing
        """
        data = self.load(symbol, timeframe, max_age_override=float('inf'))  # Ignore age for conversion

        if not data:
            return None

        # Convert to DataFrame
        df = pd.DataFrame(data)

        # Convert timestamp to datetime
        df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')

        # Set datetime as index
        df.set_index('datetime', inplace=True)

        return df

# Create a global instance for easy access
ohlcv_cache = OHLCVCache()
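A short usage sketch for the cache above (fetch_ohlcv_somehow is a placeholder for whatever exchange wrapper supplies candles, not a real function):

candles = fetch_ohlcv_somehow('ETH/USDT', '1m')   # placeholder fetcher
if candles:
    ohlcv_cache.save(candles, 'ETH/USDT', '1m')
else:
    # Fall back to cached data when the exchange is unavailable
    candles = ohlcv_cache.load('ETH/USDT', '1m')  # None if missing or stale

df = ohlcv_cache.to_dataframe('ETH/USDT', '1m')   # datetime-indexed OHLCV DataFrame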
@ -1,449 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
class EnhancedPricePredictionModel(nn.Module):
|
||||
def __init__(self, input_dim=2, hidden_dim=256, num_layers=3, output_dim=5, num_timeframes=3):
|
||||
super(EnhancedPricePredictionModel, self).__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_timeframes = num_timeframes
|
||||
|
||||
# Separate LSTM for each timeframe
|
||||
self.timeframe_lstms = nn.ModuleList([
|
||||
nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
|
||||
for _ in range(num_timeframes)
|
||||
])
|
||||
|
||||
# Cross-timeframe attention
|
||||
self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
|
||||
|
||||
# Self-attention for each timeframe
|
||||
self.self_attentions = nn.ModuleList([
|
||||
nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True, dropout=0.1)
|
||||
for _ in range(num_timeframes)
|
||||
])
|
||||
|
||||
# Timeframe fusion layer
|
||||
self.fusion_layer = nn.Sequential(
|
||||
nn.Linear(hidden_dim * num_timeframes, hidden_dim * 2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_dim * 2, hidden_dim)
|
||||
)
|
||||
|
||||
# Fully connected layer for price prediction
|
||||
self.price_fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.LeakyReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim, output_dim)
|
||||
)
|
||||
|
||||
# Fully connected layer for extrema prediction (high and low points)
|
||||
self.extrema_fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.LeakyReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim, 10) # 5 time steps, 2 classes (high/low) each
|
||||
)
|
||||
|
||||
# Volume prediction layer
|
||||
self.volume_fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.LeakyReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, x_list):
|
||||
# x_list is a list of tensors, one for each timeframe
|
||||
# Each x shape: (batch_size, seq_len, input_dim)
|
||||
|
||||
# Process each timeframe with its own LSTM
|
||||
lstm_outputs = []
|
||||
for i, x in enumerate(x_list):
|
||||
lstm_out, _ = self.timeframe_lstms[i](x) # lstm_out: (batch_size, seq_len, hidden_dim)
|
||||
lstm_outputs.append(lstm_out)
|
||||
|
||||
# Apply self-attention to each timeframe
|
||||
attn_outputs = []
|
||||
for i, lstm_out in enumerate(lstm_outputs):
|
||||
attn_output, _ = self.self_attentions[i](lstm_out, lstm_out, lstm_out)
|
||||
attn_outputs.append(attn_output[:, -1, :]) # Use the last time step
|
||||
|
||||
# Concatenate all timeframe representations
|
||||
combined = torch.cat(attn_outputs, dim=1) # (batch_size, hidden_dim * num_timeframes)
|
||||
|
||||
# Fuse timeframe information
|
||||
fused = self.fusion_layer(combined) # (batch_size, hidden_dim)
|
||||
|
||||
# Price prediction
|
||||
price_pred = self.price_fc(fused)
|
||||
|
||||
# Extrema prediction
|
||||
extrema_logits = self.extrema_fc(fused)
|
||||
|
||||
# Volume prediction
|
||||
volume_pred = self.volume_fc(fused)
|
||||
|
||||
return price_pred, extrema_logits, volume_pred


class EnhancedDQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=512):
        super(EnhancedDQN, self).__init__()

        # Feature extraction layers with increased capacity
        self.feature_extraction = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        # Advantage stream with increased capacity
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, action_dim)
        )

        # Value stream with increased capacity
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Enhanced transformer for temporal dependencies
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=3)

        # LSTM for sequential decision making with increased capacity
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.1)

        # Final layers with increased capacity
        self.final_layers = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, action_dim)
        )

        # Market regime classification layer
        self.market_regime_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 3)  # 3 regimes: trending, ranging, volatile
        )

    def forward(self, state, hidden=None):
        # Extract features
        features = self.feature_extraction(state)
        features = features.unsqueeze(1)  # Add sequence dimension for transformer/LSTM

        # Transformer processing
        transformer_out = self.transformer(features)

        # LSTM processing (pass the caller's hidden state so the `hidden`
        # argument is actually used; None falls back to a zero state)
        lstm_out, lstm_hidden = self.lstm(transformer_out, hidden)

        # Dueling architecture
        advantage = self.advantage_stream(features.squeeze(1))
        value = self.value_stream(features.squeeze(1))

        # Combine transformer, LSTM and dueling outputs
        combined = torch.cat([transformer_out.squeeze(1), lstm_out.squeeze(1)], dim=1)
        q_values = self.final_layers(combined)

        # Market regime classification
        market_regime = self.market_regime_classifier(transformer_out.squeeze(1))

        # Dueling Q-value computation
        dueling_q = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Final Q-values are a weighted combination of the dueling Q-values and the direct Q-values
        # This allows the model to use either approach depending on the situation
        q_values = 0.5 * dueling_q + 0.5 * q_values

        return q_values, lstm_hidden, market_regime
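

# Illustrative shape check for EnhancedDQN: state_dim=100 matches the padded
# state vector produced by ExchangeSimulator._get_state(), and action_dim=3
# matches the HOLD/BUY/SELL action space used by its step() method.
def _demo_enhanced_dqn():
    dqn = EnhancedDQN(state_dim=100, action_dim=3)
    state = torch.randn(8, 100)              # batch of 8 states
    q_values, lstm_hidden, market_regime = dqn(state)
    greedy_actions = q_values.argmax(dim=1)  # one action per sample
    return q_values.shape, market_regime.shape  # (8, 3) and (8, 3)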


class EnhancedReplayBuffer:
    """Enhanced replay buffer with prioritized experience replay and n-step returns"""
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, n_step=3, gamma=0.99):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        self.priorities = torch.zeros(capacity)
        self.alpha = alpha                    # Priority exponent
        self.beta = beta                      # Importance sampling weight
        self.beta_increment = beta_increment  # Beta annealing
        self.n_step = n_step                  # n-step returns
        self.gamma = gamma                    # Discount factor
        self.n_step_buffer = []
        self.max_priority = 1.0

    def add(self, state, action, reward, next_state, done):
        """
        Add a new experience to the buffer (simplified version of push for compatibility)

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether the episode is done
        """
        # Store in replay buffer with max priority
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)

        # Set priority to max priority to ensure it gets sampled
        self.priorities[self.position] = self.max_priority

        # Move position pointer
        self.position = (self.position + 1) % self.capacity

    def push(self, state, action, reward, next_state, done):
        # Store experience in n-step buffer
        self.n_step_buffer.append((state, action, reward, next_state, done))

        # If we don't have enough experiences for an n-step return, wait
        if len(self.n_step_buffer) < self.n_step and not done:
            return

        # Calculate n-step return
        reward_n = 0
        for i in range(self.n_step):
            if i >= len(self.n_step_buffer):
                break
            reward_n += self.gamma**i * self.n_step_buffer[i][2]

        # Get state, action from the first experience
        state = self.n_step_buffer[0][0]
        action = self.n_step_buffer[0][1]

        # Get next_state, done from the last experience
        next_state = self.n_step_buffer[-1][3]
        done = self.n_step_buffer[-1][4]

        # Store in replay buffer with max priority
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward_n, next_state, done)

        # Set priority to max priority to ensure it gets sampled
        self.priorities[self.position] = self.max_priority

        # Move position pointer
        self.position = (self.position + 1) % self.capacity

        # Remove the first experience from the n-step buffer
        self.n_step_buffer.pop(0)

        # If the episode is done, clear the n-step buffer
        if done:
            self.n_step_buffer = []

    def sample(self, batch_size):
        # Calculate sampling probabilities
        if len(self.buffer) < self.capacity:
            probs = self.priorities[:len(self.buffer)]
        else:
            probs = self.priorities

        # Normalize probabilities
        probs = probs ** self.alpha
        probs = probs / probs.sum()

        # Sample indices based on priorities
        indices = torch.multinomial(probs, batch_size, replacement=True)

        # Get samples
        states = []
        actions = []
        rewards = []
        next_states = []
        dones = []

        # Calculate importance sampling weights
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights = weights / weights.max()
        self.beta = min(1.0, self.beta + self.beta_increment)  # Anneal beta

        # Get experiences
        for idx in indices:
            state, action, reward, next_state, done = self.buffer[idx]
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

        # Return only the states, actions, rewards, next_states, dones for compatibility with the learn function
        return states, actions, rewards, next_states, dones

    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors):
            # Update priority based on TD error
            priority = float(abs(td_error) + 1e-5)  # Small constant to ensure non-zero priority
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return len(self.buffer)
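

# Illustrative n-step walkthrough: with n_step=3 and gamma=0.99, the first
# stored transition carries reward_n = r0 + 0.99*r1 + 0.99**2*r2. The zero
# states and unit rewards below are placeholders.
def _demo_replay_buffer():
    buffer = EnhancedReplayBuffer(capacity=10000, n_step=3, gamma=0.99)
    zero_state = [0.0] * 100
    for t in range(5):
        buffer.push(zero_state, 0, 1.0, zero_state, done=(t == 4))
    # Returns (states, actions, rewards, next_states, dones)
    return buffer.sample(batch_size=2)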


def train_price_predictor(model, data_loaders, optimizer, device, epochs=10):
    """
    Train the price prediction model using data from multiple timeframes

    Args:
        model: The EnhancedPricePredictionModel
        data_loaders: List of DataLoader objects, one for each timeframe
        optimizer: Optimizer for training
        device: Device to train on (CPU or GPU)
        epochs: Number of training epochs
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0

        # Assume all dataloaders have the same length
        for batch_idx, batch_data in enumerate(zip(*data_loaders)):
            # Each batch_data is a tuple of (inputs, price_targets, extrema_targets, volume_targets) for each timeframe
            optimizer.zero_grad()

            # Prepare inputs for each timeframe
            inputs_list = [data[0].to(device) for data in batch_data]
            price_targets = batch_data[0][1].to(device)  # Use targets from the first timeframe (e.g., 1m)
            extrema_targets = batch_data[0][2].to(device)
            volume_targets = batch_data[0][3].to(device)

            # Forward pass
            price_pred, extrema_logits, volume_pred = model(inputs_list)

            # Calculate losses
            price_loss = F.mse_loss(price_pred, price_targets)
            extrema_loss = F.binary_cross_entropy_with_logits(extrema_logits, extrema_targets)
            volume_loss = F.mse_loss(volume_pred, volume_targets)

            # Combined loss with weighting
            loss = price_loss + 0.5 * extrema_loss + 0.3 * volume_loss

            # Backward pass
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}")

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.6f}")

        # Learning rate scheduling
        if epoch > 0 and epoch % 5 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.9

    return model
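

# Sketch of wiring the trainer together end to end. The model constructor
# arguments are assumptions (see the smoke test above); ExchangeSimulator
# comes from exchange_simulator.py in this repo.
def _demo_train_price_predictor():
    from exchange_simulator import ExchangeSimulator
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedPricePredictionModel(input_dim=2, hidden_dim=128, output_dim=5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    data_loaders = prepare_multi_timeframe_data(ExchangeSimulator())
    return train_price_predictor(model, data_loaders, optimizer, device, epochs=10)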


def prepare_multi_timeframe_data(exchange, timeframes=['1m', '15m', '1h'], lookback=30):
    """
    Prepare data from multiple timeframes for training

    Args:
        exchange: Exchange object to fetch data from
        timeframes: List of timeframes to fetch
        lookback: Number of candles to look back

    Returns:
        List of DataLoader objects, one for each timeframe
    """
    data_loaders = []

    for timeframe in timeframes:
        # Fetch historical data for this timeframe
        candles = exchange.fetch_ohlcv(timeframe=timeframe, limit=1000)

        # Prepare inputs and targets
        inputs = []
        price_targets = []
        extrema_targets = []
        volume_targets = []

        for i in range(lookback, len(candles) - 5):
            # Input: lookback candles (close price and volume)
            input_data = torch.tensor([
                [candle[4], candle[5]] for candle in candles[i-lookback:i]
            ], dtype=torch.float32)

            # Target: next 5 candles (close price)
            price_target = torch.tensor([
                candle[4] for candle in candles[i:i+5]
            ], dtype=torch.float32)

            # Target: extrema points in the next 5 candles
            extrema_target = torch.zeros(10, dtype=torch.float32)  # 5 time steps, 2 classes each
            for j in range(5):
                # Simple extrema detection for demonstration
                if j > 0 and j < 4:
                    # Local high
                    if candles[i+j][2] > candles[i+j-1][2] and candles[i+j][2] > candles[i+j+1][2]:
                        extrema_target[j*2] = 1.0
                    # Local low
                    if candles[i+j][3] < candles[i+j-1][3] and candles[i+j][3] < candles[i+j+1][3]:
                        extrema_target[j*2+1] = 1.0

            # Target: volume for the next 5 candles
            volume_target = torch.tensor([
                candle[5] for candle in candles[i:i+5]
            ], dtype=torch.float32)

            inputs.append(input_data)
            price_targets.append(price_target)
            extrema_targets.append(extrema_target)
            volume_targets.append(volume_target)

        # Create dataset and dataloader
        dataset = torch.utils.data.TensorDataset(
            torch.stack(inputs),
            torch.stack(price_targets),
            torch.stack(extrema_targets),
            torch.stack(volume_targets)
        )

        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=32, shuffle=True
        )

        data_loaders.append(data_loader)

    return data_loaders
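

# Shape check for the loaders built above (lookback=30 candles of
# [close, volume], a 5-candle prediction horizon, batch_size=32).
def _demo_prepare_data():
    from exchange_simulator import ExchangeSimulator
    loaders = prepare_multi_timeframe_data(ExchangeSimulator(), timeframes=['1m', '15m', '1h'], lookback=30)
    inputs, price_t, extrema_t, volume_t = next(iter(loaders[0]))
    # inputs: (32, 30, 2), price_t: (32, 5), extrema_t: (32, 10), volume_t: (32, 5)
    return inputs.shape, price_t.shape, extrema_t.shape, volume_t.shape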
File diff suppressed because it is too large

@ -1,806 +0,0 @@
import numpy as np
import pandas as pd
import os
import random
from datetime import datetime, timedelta

class ExchangeSimulator:
    """
    A simple exchange simulator that generates realistic market data
    for testing trading algorithms without connecting to a real exchange.
    """

    def __init__(self, symbol="BTC/USDT", seed=42):
        """
        Initialize the exchange simulator

        Args:
            symbol: Trading pair symbol
            seed: Random seed for reproducibility
        """
        self.symbol = symbol
        self.seed = seed
        np.random.seed(seed)
        random.seed(seed)

        # Initialize data storage
        self.data = {}
        self.current_timestamp = datetime.now()

        # Generate initial data for different timeframes
        self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d']
        self.timeframe_minutes = {
            '1m': 1,
            '5m': 5,
            '15m': 15,
            '30m': 30,
            '1h': 60,
            '4h': 240,
            '1d': 1440
        }

        # Generate initial price around $50,000 (for BTC/USDT)
        self.base_price = 50000.0

        # Trading state (also set in reset()); initialized here so step()
        # works even if reset() has not been called first
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.entry_time = 0

        # Generate data for each timeframe
        for tf in self.timeframes:
            self._generate_initial_data(tf)

    def _generate_initial_data(self, timeframe, num_candles=1000):
        """
        Generate initial historical data for a specific timeframe

        Args:
            timeframe: Timeframe to generate data for
            num_candles: Number of candles to generate
        """
        # Calculate time delta for this timeframe
        minutes = self.timeframe_minutes[timeframe]

        # Generate timestamps
        end_time = self.current_timestamp
        timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)]
        timestamps.reverse()  # Oldest first

        # Generate price data with realistic patterns
        prices = self._generate_price_series(num_candles)

        # Generate volume data with realistic patterns
        volumes = self._generate_volume_series(num_candles, timeframe)

        # Create OHLCV data
        ohlcv_data = []
        for i in range(num_candles):
            # Calculate OHLC based on close price
            close = prices[i]
            high = close * (1 + np.random.uniform(0, 0.01))
            low = close * (1 - np.random.uniform(0, 0.01))
            open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005))

            # Create candle
            candle = [
                int(timestamps[i].timestamp() * 1000),  # Timestamp in milliseconds
                open_price,   # Open
                high,         # High
                low,          # Low
                close,        # Close
                volumes[i]    # Volume
            ]
            ohlcv_data.append(candle)

        # Store data
        self.data[timeframe] = ohlcv_data

    def _generate_price_series(self, length):
        """
        Generate a realistic price series with trends, reversals, and volatility

        Args:
            length: Number of prices to generate

        Returns:
            List of prices
        """
        # Start with base price
        prices = [self.base_price]

        # Parameters for price generation
        trend_strength = 0.001  # Strength of trend
        volatility = 0.005      # Daily volatility
        mean_reversion = 0.001  # Mean reversion strength

        # Generate price series
        for i in range(1, length):
            # Change trend direction every ~100 candles
            if i % 100 == 0:
                trend_strength = -trend_strength

            # Calculate price change
            trend = trend_strength * prices[-1]
            random_change = np.random.normal(0, volatility) * prices[-1]
            mean_reversion_change = mean_reversion * (self.base_price - prices[-1])

            # Calculate new price
            new_price = prices[-1] + trend + random_change + mean_reversion_change

            # Ensure price doesn't drop more than 10% in one step
            new_price = max(new_price, prices[-1] * 0.9)

            prices.append(new_price)

        return prices

    def _generate_volume_series(self, length, timeframe):
        """
        Generate a realistic volume series with patterns

        Args:
            length: Number of volumes to generate
            timeframe: Timeframe for volume scaling

        Returns:
            List of volumes
        """
        # Base volume depends on timeframe
        base_volume = {
            '1m': 10,
            '5m': 50,
            '15m': 150,
            '30m': 300,
            '1h': 600,
            '4h': 2400,
            '1d': 10000
        }[timeframe]

        # Generate volume series
        volumes = []
        for i in range(length):
            # Volume tends to be higher at trend reversals and during volatile periods
            cycle_factor = 1 + 0.5 * np.sin(i / 20)      # Cyclical pattern
            random_factor = np.random.lognormal(0, 0.5)  # Random spikes

            # Calculate volume
            volume = base_volume * cycle_factor * random_factor

            # Add some volume spikes
            if random.random() < 0.05:  # 5% chance of volume spike
                volume *= random.uniform(2, 5)

            volumes.append(volume)

        return volumes

    def fetch_ohlcv(self, timeframe='1m', limit=100, since=None):
        """
        Fetch OHLCV data for a specific timeframe

        Args:
            timeframe: Timeframe to fetch data for
            limit: Number of candles to fetch
            since: Timestamp to fetch data since (not used in simulator)

        Returns:
            List of OHLCV candles
        """
        # Ensure timeframe exists
        if timeframe not in self.data:
            if timeframe in self.timeframe_minutes:
                self._generate_initial_data(timeframe)
            else:
                # Default to 1m if timeframe not supported
                timeframe = '1m'

        # Get data
        data = self.data[timeframe]

        # Return limited data
        return data[-limit:]

    def update(self):
        """
        Update the exchange data by generating a new candle for each timeframe
        """
        # Update current timestamp
        self.current_timestamp = datetime.now()

        # Update each timeframe
        for tf in self.timeframes:
            self._add_new_candle(tf)

    def _add_new_candle(self, timeframe):
        """
        Add a new candle to the specified timeframe

        Args:
            timeframe: Timeframe to add candle to
        """
        # Get existing data
        data = self.data[timeframe]

        # Get last close price
        last_close = data[-1][4]

        # Calculate time delta for this timeframe
        minutes = self.timeframe_minutes[timeframe]

        # Calculate new timestamp
        new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000)

        # Generate new price with some randomness
        price_change = np.random.normal(0, 0.002) * last_close
        new_close = last_close + price_change

        # Calculate OHLC
        new_open = last_close
        new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005))
        new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005))

        # Generate volume
        base_volume = data[-1][5]
        volume_change = np.random.normal(0, 0.2) * base_volume
        new_volume = max(base_volume + volume_change, base_volume * 0.5)

        # Create new candle
        new_candle = [
            new_timestamp,
            new_open,
            new_high,
            new_low,
            new_close,
            new_volume
        ]

        # Add to data
        self.data[timeframe].append(new_candle)

    def get_ticker(self, symbol=None):
        """
        Get current ticker information

        Args:
            symbol: Symbol to get ticker for (defaults to initialized symbol)

        Returns:
            Dictionary with ticker information
        """
        if symbol is None:
            symbol = self.symbol

        # Get latest 1m candle
        latest_candle = self.data['1m'][-1]

        return {
            'symbol': symbol,
            'bid': latest_candle[4] * 0.9999,  # Slightly below last price
            'ask': latest_candle[4] * 1.0001,  # Slightly above last price
            'last': latest_candle[4],
            'high': latest_candle[2],
            'low': latest_candle[3],
            'volume': latest_candle[5],
            'timestamp': latest_candle[0]
        }

    def create_order(self, symbol, type, side, amount, price=None):
        """
        Simulate creating an order

        Args:
            symbol: Symbol to create order for
            type: Order type (limit, market)
            side: Order side (buy, sell)
            amount: Order amount
            price: Order price (for limit orders)

        Returns:
            Dictionary with order information
        """
        # Get current ticker
        ticker = self.get_ticker(symbol)

        # Determine execution price
        if type == 'market':
            if side == 'buy':
                execution_price = ticker['ask']
            else:
                execution_price = ticker['bid']
        else:  # limit order
            execution_price = price

        # Create order object
        order = {
            'id': f"order_{int(datetime.now().timestamp() * 1000)}",
            'symbol': symbol,
            'type': type,
            'side': side,
            'amount': amount,
            'price': execution_price,
            'cost': amount * execution_price,
            'filled': amount,
            'status': 'closed',
            'timestamp': int(datetime.now().timestamp() * 1000)
        }

        return order

    def fetch_balance(self):
        """
        Fetch account balance (simulated)

        Returns:
            Dictionary with balance information
        """
        return {
            'total': {
                'USD': 10000.0,
                'BTC': 1.0
            },
            'free': {
                'USD': 5000.0,
                'BTC': 0.5
            },
            'used': {
                'USD': 5000.0,
                'BTC': 0.5
            }
        }

    def reset(self):
        """
        Reset the exchange simulator to its initial state

        Returns:
            Self for method chaining
        """
        # Reset timestamp
        self.current_timestamp = datetime.now()

        # Regenerate data for each timeframe
        for tf in self.timeframes:
            self._generate_initial_data(tf)

        # Reset any internal state
        self.position = 'flat'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0

        # Reset prediction history if it exists
        if hasattr(self, 'prediction_history'):
            self.prediction_history = []

        return self

    def step(self, action):
        """
        Take a step in the environment by executing an action

        Args:
            action: Action to take (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT)

        Returns:
            next_state: Next state after taking action
            reward: Reward received
            done: Whether episode is done
            info: Additional information
        """
        # Get current price
        current_price = self.data['1m'][-1][4]

        # Initialize info dictionary
        info = {
            'price': current_price,
            'timestamp': self.data['1m'][-1][0],
            'trade': None
        }

        # Process action
        if action == 0:  # HOLD
            pass  # No action needed

        elif action == 1:  # BUY/LONG
            if self.position == 'flat':
                # Open a new long position
                self.position = 'long'
                self.entry_price = current_price
                self.position_size = 100  # Simplified position sizing

                # Set stop loss and take profit levels
                self.stop_loss = current_price * 0.99    # 1% stop loss
                self.take_profit = current_price * 1.02  # 2% take profit

                # Record entry time
                self.entry_time = self.data['1m'][-1][0]

                # Add to info
                info['trade'] = {
                    'type': 'long',
                    'entry': current_price,
                    'entry_time': self.data['1m'][-1][0],
                    'size': self.position_size,
                    'stop_loss': self.stop_loss,
                    'take_profit': self.take_profit
                }

            elif self.position == 'short':
                # Close short position and open long
                pnl = self.entry_price - current_price
                pnl_percent = pnl / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': current_price,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Open new long position
                self.position = 'long'
                self.entry_price = current_price
                self.position_size = 100  # Simplified position sizing

                # Set stop loss and take profit levels
                self.stop_loss = current_price * 0.99    # 1% stop loss
                self.take_profit = current_price * 1.02  # 2% take profit

                # Record entry time
                self.entry_time = self.data['1m'][-1][0]

        elif action == 2:  # SELL/SHORT
            if self.position == 'flat':
                # Open a new short position
                self.position = 'short'
                self.entry_price = current_price
                self.position_size = 100  # Simplified position sizing

                # Set stop loss and take profit levels
                self.stop_loss = current_price * 1.01    # 1% stop loss
                self.take_profit = current_price * 0.98  # 2% take profit

                # Record entry time
                self.entry_time = self.data['1m'][-1][0]

                # Add to info
                info['trade'] = {
                    'type': 'short',
                    'entry': current_price,
                    'entry_time': self.data['1m'][-1][0],
                    'size': self.position_size,
                    'stop_loss': self.stop_loss,
                    'take_profit': self.take_profit
                }

            elif self.position == 'long':
                # Close long position and open short
                pnl = current_price - self.entry_price
                pnl_percent = pnl / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': current_price,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Open new short position
                self.position = 'short'
                self.entry_price = current_price
                self.position_size = 100  # Simplified position sizing

                # Set stop loss and take profit levels
                self.stop_loss = current_price * 1.01    # 1% stop loss
                self.take_profit = current_price * 0.98  # 2% take profit

                # Record entry time
                self.entry_time = self.data['1m'][-1][0]

        # Generate next candle
        self._add_new_candle('1m')

        # Check if stop loss or take profit has been hit
        self._check_sl_tp(info)

        # Validate predictions if available
        if hasattr(self, 'prediction_history') and len(self.prediction_history) > 0:
            self.validate_predictions(self.data['1m'][-1])

        # Prepare next state (simplified)
        next_state = self._get_state()

        # Calculate reward (simplified)
        reward = 0
        if info['trade'] is not None and 'pnl_dollar' in info['trade']:
            reward = info['trade']['pnl_dollar']

        # Check if done (simplified)
        done = False

        return next_state, reward, done, info

    def _get_state(self):
        """
        Get the current state of the environment

        Returns:
            List representing the current state
        """
        # Simplified state representation
        state = []

        # Add price features
        for tf in ['1m', '5m', '15m']:
            if tf in self.data:
                # Get last 10 candles
                candles = self.data[tf][-10:]

                # Extract close prices
                prices = [candle[4] for candle in candles]

                # Calculate price changes
                price_changes = [prices[i] / prices[i-1] - 1 for i in range(1, len(prices))]

                # Add to state
                state.extend(price_changes)

                # Add current price relative to SMAs
                sma_5 = sum(prices[-5:]) / 5
                sma_10 = sum(prices) / 10
                state.append(prices[-1] / sma_5 - 1)
                state.append(prices[-1] / sma_10 - 1)

        # Pad state to 100 dimensions
        while len(state) < 100:
            state.append(0)

        # Ensure state has exactly 100 dimensions
        if len(state) > 100:
            state = state[:100]

        return state

    def _check_sl_tp(self, info):
        """
        Check if stop loss or take profit has been hit

        Args:
            info: Info dictionary to update
        """
        if self.position == 'flat':
            return

        # Get current price
        current_price = self.data['1m'][-1][4]

        if self.position == 'long':
            # Check stop loss
            if current_price <= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': 'stop_loss',
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

            # Check take profit
            elif current_price >= self.take_profit:
                # Take profit hit
                pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': 'take_profit',
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

        elif self.position == 'short':
            # Check stop loss
            if current_price >= self.stop_loss:
                # Stop loss hit
                pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.stop_loss,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': 'stop_loss',
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

            # Check take profit
            elif current_price <= self.take_profit:
                # Take profit hit
                pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Add to info
                info['trade'] = {
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.take_profit,
                    'entry_time': self.entry_time,
                    'exit_time': self.data['1m'][-1][0],
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': 'take_profit',
                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
                }

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

    def validate_predictions(self, new_candle):
        """
        Validate previous extrema predictions against new candle data

        Args:
            new_candle: New candle data to validate against
        """
        if not hasattr(self, 'prediction_history') or not self.prediction_history:
            return None

        # Extract candle data
        timestamp = new_candle[0]
        high_price = new_candle[2]
        low_price = new_candle[3]

        # Track validation metrics
        validated_count = 0
        correct_count = 0

        # Check each prediction that hasn't been validated yet
        for pred in self.prediction_history:
            if pred.get('validated', False):
                continue

            # Check if this prediction's time has come (or passed)
            if 'predicted_timestamp' in pred and timestamp >= pred['predicted_timestamp']:
                pred['validated'] = True
                validated_count += 1

                # Check if prediction was correct
                if pred['type'] == 'low':
                    # A low prediction is correct if price came within 0.5% of the predicted low
                    price_diff_percent = abs(low_price - pred['price']) / pred['price'] * 100
                    pred['actual_price'] = low_price
                    pred['price_diff_percent'] = price_diff_percent

                    # Consider correct if within 0.5% or price went lower than predicted
                    was_correct = price_diff_percent < 0.5 or low_price <= pred['price']
                    pred['was_correct'] = was_correct

                    if was_correct:
                        correct_count += 1

                elif pred['type'] == 'high':
                    # A high prediction is correct if price came within 0.5% of the predicted high
                    price_diff_percent = abs(high_price - pred['price']) / pred['price'] * 100
                    pred['actual_price'] = high_price
                    pred['price_diff_percent'] = price_diff_percent

                    # Consider correct if within 0.5% or price went higher than predicted
                    was_correct = price_diff_percent < 0.5 or high_price >= pred['price']
                    pred['was_correct'] = was_correct

                    if was_correct:
                        correct_count += 1

        # Return validation metrics
        if validated_count > 0:
            return {
                'validated_count': validated_count,
                'correct_count': correct_count,
                'accuracy': correct_count / validated_count
            }

        return None

    def calculate_pnl(self):
        """
        Calculate the current profit/loss of the open position

        Returns:
            float: Current PnL in dollars, 0 if no position is open
        """
        if self.position == 'flat':
            return 0.0

        current_price = self.data['1m'][-1][4]

        if self.position == 'long':
            pnl_percent = (current_price - self.entry_price) / self.entry_price * 100
        elif self.position == 'short':
            pnl_percent = (self.entry_price - current_price) / self.entry_price * 100
        else:
            return 0.0

        pnl_dollar = pnl_percent / 100 * self.position_size
        return pnl_dollar


# Example usage
if __name__ == "__main__":
    # Create exchange simulator
    exchange = ExchangeSimulator()

    # Fetch some data
    ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10)
    print("OHLCV data (1h timeframe):")
    for candle in ohlcv[-5:]:
        timestamp = datetime.fromtimestamp(candle[0] / 1000)
        print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}")

    # Get current ticker
    ticker = exchange.get_ticker()
    print(f"\nCurrent ticker: {ticker['last']:.2f}")

    # Create a market buy order
    order = exchange.create_order("BTC/USDT", "market", "buy", 0.1)
    print(f"\nCreated order: {order}")

    # Update the exchange (simulate time passing)
    exchange.update()

    # Get updated ticker
    updated_ticker = exchange.get_ticker()
    print(f"\nUpdated ticker: {updated_ticker['last']:.2f}")

@ -1,42 +0,0 @@
def fix_indentation():
    with open('main.py', 'r') as f:
        lines = f.readlines()

    # Fix indentation for the problematic sections
    fixed_lines = []

    # Find the try block that starts at line 1693
    try_start_line = 1693
    try_block_found = False
    in_try_block = False

    for i, line in enumerate(lines):
        # Check if we're at the try statement
        if i+1 == try_start_line and 'try:' in line:
            try_block_found = True
            in_try_block = True
            fixed_lines.append(line)
        # Fix the indentation of the experiences line
        elif i+1 == 1695 and line.strip().startswith('experiences = self.memory.sample(BATCH_SIZE)'):
            # Add proper indentation (4 spaces)
            fixed_lines.append('    ' + line.lstrip())
        # Check if we're at the end of the try block without an except
        elif try_block_found and in_try_block and i+1 > try_start_line and line.strip() and not line.startswith(' '):
            # We've reached the end of the try block without an except, add one
            fixed_lines.append('        except Exception as e:\n')
            fixed_lines.append('            logger.error(f"Error during learning: {e}")\n')
            fixed_lines.append('            logger.error(f"Traceback: {traceback.format_exc()}")\n')
            fixed_lines.append('            return None\n\n')
            in_try_block = False
            fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.writelines(fixed_lines)

    print("Indentation fixed!")

if __name__ == "__main__":
    fix_indentation()

@ -1,22 +0,0 @@
import re

def fix_try_blocks():
    with open('main.py', 'r') as f:
        content = f.read()

    # Find all try blocks without except or finally
    pattern = r'(\s+)try:\s*\n((?:\1\s+.*\n)+?)(?!\1\s*except|\1\s*finally)'

    # Replace with try-except blocks
    fixed_content = re.sub(
        pattern,
        r'\1try:\n\2\1except Exception as e:\n\1    logger.error(f"Error: {e}")\n\1    logger.error(f"Traceback: {traceback.format_exc()}")\n\1    return None\n\n',
        content)

    # Write the fixed content back to the file
    with open('main.py', 'w') as f:
        f.write(fixed_content)

    print("Try blocks fixed!")

if __name__ == "__main__":
    fix_try_blocks()

2794 crypto/gogo2/main.py (new file)
File diff suppressed because it is too large

@ -1,24 +0,0 @@
import asyncio
from exchange_simulator import ExchangeSimulator
import logging

# Set up logging
logger = logging.getLogger(__name__)

async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
File diff suppressed because it is too large

@ -1,30 +0,0 @@
import asyncio
import logging
from exchange_simulator import ExchangeSimulator

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())

@ -1,304 +0,0 @@
import os
import json
import logging
import traceback
import numpy as np
from dotenv import load_dotenv
from mexc_api.spot import Spot
from mexc_api.common.enums import Side, OrderType

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('mexc_trading')

# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')

class MexcTradingClient:
    """Client for executing trades on MEXC exchange using the official API"""

    def __init__(self, api_key=None, api_secret=None, symbol="ETH/USDT", leverage=1):
        """Initialize the MEXC trading client"""
        self.api_key = api_key or MEXC_API_KEY
        self.api_secret = api_secret or MEXC_SECRET_KEY

        # Ensure API keys are not None
        if not self.api_key or not self.api_secret:
            logger.warning("API keys not provided. Using empty strings for public endpoints only.")
            self.api_key = ""
            self.api_secret = ""

        self.symbol = symbol
        self.formatted_symbol = symbol.replace('/', '')  # MEXC requires no slash
        self.leverage = leverage
        self.client = None
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.open_orders = []
        self.order_history = []
        self.trades = []
        self.win_count = 0
        self.loss_count = 0

        # Initialize the MEXC API client
        self.initialize_client()

    def initialize_client(self):
        """Initialize the MEXC API client using the API"""
        try:
            self.client = Spot(self.api_key, self.api_secret)
            # Test connection
            server_time = self.client.market.server_time()
            logger.info(f"MEXC API client initialized successfully. Server time: {server_time}")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize MEXC API client: {e}")
            logger.error(traceback.format_exc())
            return False

    async def fetch_account_balance(self):
        """Fetch account balance from MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key or self.api_key == "":
                logger.warning("No API keys provided. Cannot fetch account balance.")
                return 0

            account_info = self.client.account.account_info()
            if 'balances' in account_info:
                # Find USDT balance
                for asset in account_info['balances']:
                    if asset['asset'] == 'USDT':
                        return float(asset['free'])

            logger.warning("Could not find USDT balance")
            return 0
        except Exception as e:
            logger.error(f"Error fetching account balance: {e}")
            return 0

    async def fetch_open_positions(self):
        """Fetch open positions from MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key or self.api_key == "":
                logger.warning("No API keys provided. Cannot fetch open positions.")
                return []

            # Fetch open orders
            open_orders = self.client.account.open_orders(self.formatted_symbol)
            return open_orders
        except Exception as e:
            logger.error(f"Error fetching open positions: {e}")
            return []

    async def open_position(self, position_type, size, entry_price, stop_loss, take_profit):
        """Open a new position using MEXC API"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key or self.api_key == "":
                logger.warning("No API keys provided. Cannot open position.")
                return False

            # Calculate quantity based on size and price
            quantity = size / entry_price
            # Round quantity to appropriate precision
            quantity = round(quantity, 4)  # Adjust precision as needed for your asset

            # Determine order side
            side = Side.BUY if position_type == 'long' else Side.SELL

            logger.info(f"Opening {position_type} position: {quantity} {self.symbol} at market price")

            # Place market order
            order_result = self.client.account.new_order(
                self.formatted_symbol,
                side,
                OrderType.MARKET,
                str(quantity)
            )

            logger.info(f"Market order result: {order_result}")

            # Check if order was filled
            if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
                # Get actual entry price
                actual_entry_price = float(order_result.get('price', entry_price))

                # Place stop loss order
                sl_side = Side.SELL if position_type == 'long' else Side.BUY
                sl_order = self.client.account.new_order(
                    self.formatted_symbol,
                    sl_side,
                    OrderType.STOP_LOSS_LIMIT,
                    str(quantity),
                    price=str(stop_loss),
                    stop_price=str(stop_loss),
                    time_in_force="GTC"
                )

                logger.info(f"Stop loss order placed: {sl_order}")

                # Place take profit order
                tp_side = Side.SELL if position_type == 'long' else Side.BUY
                tp_order = self.client.account.new_order(
                    self.formatted_symbol,
                    tp_side,
                    OrderType.TAKE_PROFIT_LIMIT,
                    str(quantity),
                    price=str(take_profit),
                    stop_price=str(take_profit),
                    time_in_force="GTC"
                )

                logger.info(f"Take profit order placed: {tp_order}")

                # Update local state
                self.position = position_type
                self.position_size = size
                self.entry_price = actual_entry_price
                self.stop_loss = stop_loss
                self.take_profit = take_profit

                # Track orders
                self.open_orders.extend([sl_order, tp_order])
                self.order_history.append(order_result)

                logger.info(f"Successfully opened {position_type} position at {actual_entry_price}")
                return True
            else:
                logger.error(f"Failed to open position: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error opening position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def close_position(self, reason="manual_close"):
        """Close an existing position"""
        if self.position == 'flat':
            logger.info("No position to close")
            return True

        try:
            # Check if we have API keys for private endpoints
            if not self.api_key or self.api_key == "":
                logger.warning("No API keys provided. Cannot close position.")
                return False

            # First, cancel any existing stop loss/take profit orders
            try:
                self.client.account.cancel_open_orders(self.formatted_symbol)
                logger.info("Canceled all open orders")
            except Exception as e:
                logger.warning(f"Error canceling open orders: {e}")

            # Determine order side (opposite of position)
            side = Side.SELL if self.position == 'long' else Side.BUY

            # Calculate quantity
            quantity = self.position_size / self.entry_price
            # Round quantity to appropriate precision
            quantity = round(quantity, 4)  # Adjust precision as needed

            logger.info(f"Closing {self.position} position: {quantity} {self.symbol} at market price")

            # Execute market order to close position
            order_result = self.client.account.new_order(
                self.formatted_symbol,
                side,
                OrderType.MARKET,
                str(quantity)
            )

            logger.info(f"Close order result: {order_result}")

            # Check if order was filled
            if order_result.get('status') == 'FILLED' or order_result.get('status') == 'PARTIALLY_FILLED':
                # Get actual exit price
                exit_price = float(order_result.get('price', 0))

                # Calculate PnL
                if self.position == 'long':
                    pnl_percent = (exit_price - self.entry_price) / self.entry_price * 100
                else:  # short
                    pnl_percent = (self.entry_price - exit_price) / self.entry_price * 100

                pnl_dollar = pnl_percent / 100 * self.position_size

                # Record trade
                self.trades.append({
                    'type': self.position,
                    'entry': self.entry_price,
                    'exit': exit_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'reason': reason,
                    'order_id': order_result.get('orderId')
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

                # Track order history
                self.order_history.append(order_result)

                logger.info(f"Closed {self.position} position at {exit_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

                return True
            else:
                logger.error(f"Failed to close position: {order_result}")
                return False

        except Exception as e:
            logger.error(f"Error closing position: {e}")
            logger.error(traceback.format_exc())
            return False

    async def check_order_status(self, order_id):
        """Check the status of a specific order"""
        try:
            # Check if we have API keys for private endpoints
            if not self.api_key or self.api_key == "":
                logger.warning("No API keys provided. Cannot check order status.")
                return None

            order_status = self.client.account.query_order(
                self.formatted_symbol,
                order_id=order_id
            )
            return order_status
        except Exception as e:
            logger.error(f"Error checking order status: {e}")
            return None

    async def get_market_price(self):
        """Get current market price for the symbol"""
        try:
            ticker = self.client.market.ticker_price(self.formatted_symbol)
            if isinstance(ticker, list) and len(ticker) > 0 and 'price' in ticker[0]:
                return float(ticker[0]['price'])
            elif isinstance(ticker, dict) and 'price' in ticker:
                return float(ticker['price'])
            else:
                logger.error(f"Unexpected ticker format: {ticker}")
                return None
        except Exception as e:
            logger.error(f"Error getting market price: {e}")
            return None
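

# Hedged usage sketch: drives the client's coroutine API end to end. All
# prices and sizes below are illustrative, and the calls assume valid MEXC
# API keys are available in the environment (.env).
async def _demo():
    client = MexcTradingClient(symbol="ETH/USDT")
    balance = await client.fetch_account_balance()
    price = await client.get_market_price()
    if price is not None and balance > 0:
        # Risk 100 USDT with a 1% stop loss and a 2% take profit
        opened = await client.open_position('long', 100, price, price * 0.99, price * 1.02)
        if opened:
            await client.close_position(reason="demo_exit")

if __name__ == "__main__":
    import asyncio
    asyncio.run(_demo())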

@ -1,456 +0,0 @@
# mexc-api-sdk

MEXC official market and trade API SDK, easy to connect and send requests to the MEXC open API!

## Prerequisites
- To use our SDK you have to install Node.js LTS (https://aws.github.io/jsii/user-guides/lib-user/)

## Installation
1.
```
git clone https://github.com/mxcdevelop/mexc-api-sdk.git
```
2. cd dist/{language} and unzip the file
3. We offer five languages: dotnet, go, java, js, python

## Table of APIs
- [Init](#init)
- [Market](#market)
  - [Ping](#ping)
  - [Check Server Time](#check-server-time)
  - [Exchange Information](#exchange-information)
  - [Recent Trades List](#recent-trades-list)
  - [Order Book](#order-book)
  - [Old Trade Lookup](#old-trade-lookup)
  - [Aggregate Trades List](#aggregate-trades-list)
  - [kline Data](#kline-data)
  - [Current Average Price](#current-average-price)
  - [24hr Ticker Price Change Statistics](#24hr-ticker-price-change-statistics)
  - [Symbol Price Ticker](#symbol-price-ticker)
  - [Symbol Order Book Ticker](#symbol-order-book-ticker)
- [Trade](#trade)
  - [Test New Order](#test-new-order)
  - [New Order](#new-order)
  - [cancel-order](#cancel-order)
  - [Cancel all Open Orders on a Symbol](#cancel-all-open-orders-on-a-symbol)
  - [Query Order](#query-order)
  - [Current Open Orders](#current-open-orders)
  - [All Orders](#all-orders)
  - [Account Information](#account-information)
  - [Account Trade List](#account-trade-list)

## Init
```javascript
// Javascript
import * as Mexc from 'mexc-sdk';
const apiKey = 'apiKey'
const apiSecret = 'apiSecret'
const client = new Mexc.Spot(apiKey, apiSecret);
```
```go
// Go
package main
import (
	"fmt"
	"mexc-sdk/mexcsdk"
)

func main() {
	apiKey := "apiKey"
	apiSecret := "apiSecret"
	spot := mexcsdk.NewSpot(apiKey, apiSecret)
}
```
```python
# python
from mexc_sdk import Spot
spot = Spot(api_key='apiKey', api_secret='apiSecret')
```
```java
// java
import Mexc.Sdk.*;
class MyClass {
  public static void main(String[] args) {
    String apiKey = "apiKey";
    String apiSecret = "apiSecret";
    Spot mySpot = new Spot(apiKey, apiSecret);
  }
}
```
```C#
// dotnet
using System;
using System.Collections.Generic;
using Mxc.Sdk;

namespace dotnet
{
    class Program
    {
        static void Main(string[] args)
        {
            string apiKey = "apiKey";
            string apiSecret = "apiSecret";
            var spot = new Spot(apiKey, apiSecret);
        }
    }
}
```

## Market
### Ping
```javascript
client.ping()
```
### Check Server Time
```javascript
client.time()
```
### Exchange Information
```javascript
client.exchangeInfo(options: any)
options: {symbol, symbols}
/**
 * choose one parameter
 *
 * symbol :
 *   example "BNBBTC"
 *
 * symbols :
 *   array of symbols
 *   example ["BTCUSDT", "BNBBTC"]
 */
```

### Recent Trades List
```javascript
client.trades(symbol: string, options: any = { limit: 500 })
options: {limit}
/**
 * limit :
 *   number of entries returned
 *   default 500; max 1000
 */
```
### Order Book
```javascript
client.depth(symbol: string, options: any = { limit: 100 })
options: {limit}
/**
 * limit :
 *   number of entries returned
 *   default 100; max 5000
 *   valid values: [5, 10, 20, 50, 100, 500, 1000, 5000]
 */
```

### Old Trade Lookup
```javascript
client.historicalTrades(symbol: string, options: any = { limit: 500 })
options: {limit, fromId}
/**
 * limit :
 *   number of entries returned
 *   default 500; max 1000
 *
 * fromId :
 *   trade id to fetch from; default gets the most recent trades
 */
```

### Aggregate Trades List
```javascript
client.aggTrades(symbol: string, options: any = { limit: 500 })
options: {fromId, startTime, endTime, limit}
/**
 * fromId :
 *   id to get aggregate trades from, INCLUSIVE
 *
 * startTime :
 *   start at
 *
 * endTime :
 *   end at
 *
 * limit :
 *   number of entries returned
 *   default 500; max 1000
 */
```
### kline Data
```javascript
client.klines(symbol: string, interval: string, options: any = { limit: 500 })
options: {startTime, endTime, limit}
/**
 * interval :
 *   m: minute; h: hour; d: day; w: week; M: month
 *   example: "1m"
 *
 * startTime :
 *   start at
 *
 * endTime :
 *   end at
 *
 * limit :
 *   number of entries returned
 *   default 500; max 1000
 */
```
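
A Python counterpart of the kline call, assuming the jsii-generated Python client exposes the same `klines(symbol, interval, options)` signature shown for JavaScript above:

```python
from mexc_sdk import Spot

spot = Spot(api_key='apiKey', api_secret='apiSecret')
# Up to 500 one-minute candles for BTCUSDT
klines = spot.klines('BTCUSDT', '1m', options={'limit': 500})
```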
|
||||
|
||||
### Current Average Price
|
||||
```javascript
|
||||
client.avgPrice(symbol: string)
|
||||
```
|
||||
### 24hr Ticker Price Change Statistics
|
||||
```javascript
|
||||
client.ticker24hr(symbol?: string)
|
||||
```
|
||||
### Symbol Price Ticker
|
||||
```javascript
|
||||
client.tickerPrice(symbol?: string)
|
||||
```
|
||||
|
||||
### Symbol Order Book Ticker
|
||||
```javascript
|
||||
client.bookTicker(symbol?: string)
|
||||
```
|
||||
## Trade

### Test New Order
```javascript
client.newOrderTest(symbol: string, side: string, orderType: string, options: any = {})
options: { timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
/**
 * side :
 *   order side; ENUM: BUY, SELL
 *
 * orderType :
 *   order type; ENUM: LIMIT, MARKET, STOP_LOSS, STOP_LOSS_LIMIT,
 *   TAKE_PROFIT, TAKE_PROFIT_LIMIT, LIMIT_MAKER
 *
 * timeInForce :
 *   how long the order remains active before it expires
 *   GTC: active until the order is canceled
 *   IOC: fills as much as possible immediately, then cancels the unfilled remainder
 *   FOK: canceled unless the full order can be filled immediately
 *
 * quantity :
 *   target quantity (base asset)
 *
 * quoteOrderQty :
 *   the total amount to spend or receive (quote asset)
 *
 * price :
 *   target price
 *
 * newClientOrderId :
 *   a unique id among open orders; automatically generated if not sent
 *
 * stopPrice :
 *   used with STOP_LOSS, STOP_LOSS_LIMIT, TAKE_PROFIT, and TAKE_PROFIT_LIMIT orders
 *
 * icebergQty :
 *   used with LIMIT, STOP_LOSS_LIMIT, and TAKE_PROFIT_LIMIT to create an iceberg order
 *
 * newOrderRespType :
 *   sets the response JSON: ACK, RESULT, or FULL;
 *   MARKET and LIMIT orders default to FULL, all other orders default to ACK
 *
 * recvWindow :
 *   time window (ms) in which the request remains valid;
 *   cannot be greater than 60000; default: 5000
 */
```

### New Order
```javascript
client.newOrder(symbol: string, side: string, orderType: string, options: any = {})
options: { timeInForce, quantity, quoteOrderQty, price, newClientOrderId, stopPrice, icebergQty, newOrderRespType, recvWindow}
// The options take exactly the same values and meanings as for Test New Order above.
```

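A minimal sketch of how the documented options combine for a plain limit order (the keys mirror the reference above; the values are illustrative only):

```python
# Illustrative only: a GTC limit order expressed with the options above.
limit_order_options = {
    "timeInForce": "GTC",            # stay active until canceled
    "quantity": 0.5,                 # base-asset amount to buy or sell
    "price": 1900.0,                 # limit price
    "newClientOrderId": "bot-0001",  # optional client-side id
}
# A MARKET order would instead send either "quantity" or "quoteOrderQty",
# and omit "price" and "timeInForce".
```
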
### Cancel Order
```javascript
client.cancelOrder(symbol: string, options: any = {})
options: { orderId, origClientOrderId, newClientOrderId}
/**
 * Either orderId or origClientOrderId must be sent.
 *
 * orderId :
 *   target orderId
 *
 * origClientOrderId :
 *   target origClientOrderId
 *
 * newClientOrderId :
 *   used to uniquely identify this cancel; automatically generated by default
 */
```

### Cancel all Open Orders on a Symbol
```javascript
client.cancelOpenOrders(symbol: string)
```

### Query Order
```javascript
client.queryOrder(symbol: string, options: any = {})
options: { orderId, origClientOrderId}
/**
 * Either orderId or origClientOrderId must be sent.
 *
 * orderId :
 *   target orderId
 *
 * origClientOrderId :
 *   target origClientOrderId
 */
```

### Current Open Orders
```javascript
client.openOrders(symbol: string)
```

### All Orders
```javascript
client.allOrders(symbol: string, options: any = { limit: 500 })
options: { orderId, startTime, endTime, limit}
/**
 * orderId :
 *   target orderId
 *
 * startTime / endTime :
 *   timestamps bounding the returned orders
 *
 * limit :
 *   number of records returned; default 500, max 1000
 */
```

### Account Information
```javascript
client.accountInfo()
```

### Account Trade List
```javascript
client.accountTradeList(symbol: string, options: any = { limit: 500 })
options: { orderId, startTime, endTime, fromId, limit}
/**
 * orderId :
 *   target orderId
 *
 * startTime / endTime :
 *   timestamps bounding the returned trades
 *
 * fromId :
 *   trade id to fetch from; defaults to the most recent trades
 *
 * limit :
 *   number of records returned; default 500, max 1000
 */
```
@ -1,91 +1,188 @@
# Crypto Trading Bot with MEXC API Integration
# Crypto Trading Bot with Reinforcement Learning

This is an AI-powered cryptocurrency trading bot that can run in both simulation (demo) mode and live trading mode using the MEXC exchange API.
An automated cryptocurrency trading bot that uses Deep Q-Learning (DQN) to trade ETH/USDT on the MEXC exchange. The bot features a sophisticated neural network architecture with LSTM layers and attention mechanisms for better pattern recognition.

## Features

- Deep Reinforcement Learning agent for trading decisions
- Technical indicators and price prediction
- Live trading integration with MEXC exchange via mexc-api
- Demo mode for testing without real trades
- Real-time data streaming via websockets
- Performance tracking and visualization
- Deep Q-Learning with experience replay
- LSTM layers for sequential data processing
- Multi-head attention mechanism
- Dueling DQN architecture
- Real-time trading capabilities
- TensorBoard integration for monitoring
- Comprehensive technical indicators
- Demo and live trading modes
- Automatic model checkpointing

## Setup
## Prerequisites

1. Clone the repository
2. Install dependencies:
```
pip install -r requirements.txt
```
3. Create a `.env` file in the root directory with your MEXC API keys:
```
MEXC_API_KEY=your_api_key_here
MEXC_SECRET_KEY=your_secret_key_here
```
- Python 3.8+
- MEXC Exchange API credentials
- GPU recommended but not required

## Installation

1. Clone the repository:

```bash
git clone https://github.com/yourusername/crypto-trading-bot.git
cd crypto-trading-bot
```
2. Create a virtual environment:

```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```
3. Install dependencies:

```bash
pip install -r requirements.txt
```

4. Create a `.env` file in the project root with your MEXC API credentials:

```bash
MEXC_API_KEY=your_api_key
MEXC_API_SECRET=your_api_secret
```

For CUDA support (optional):

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
```
## Usage

The bot can be run in three different modes:
The bot can be run in three modes:

### Training Mode

Train the agent on historical data:

```
python main.py --mode train --episodes 100
```
```bash
python main.py --mode train --episodes 1000
```

### Evaluation Mode

Evaluate the trained agent on historical data:

```
python main.py --mode evaluate
```
```bash
python main.py --mode eval --episodes 10
```

### Live Trading Mode

Run the bot in live trading mode:

```bash
# Demo mode (simulated trading with real market data)
python main.py --mode live --demo

# Real trading (actual trades on MEXC)
python main.py --mode live
```

To run in demo mode (no real trades):
Demo mode simulates trading using real-time market data but does not execute actual trades. It still:
- Logs all trading decisions and performance metrics
- Updates the model based on market data (if in training mode)
- Displays real-time analytics and position information
- Calculates theoretical profits/losses
- Saves performance data to TensorBoard

```
python main.py --mode live --demo
```

## Live Trading Implementation

The bot uses the mexc-api package to execute trades on the MEXC exchange. The implementation includes:

- Market order execution for opening and closing positions
- Stop loss and take profit orders
- Real-time balance updates
- Trade history tracking
This makes it perfect for testing strategies without financial risk.
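The concrete interface behind this is visible in the trading-client test deleted further down this diff; a minimal sketch of driving it (method names are taken from that test, while the surrounding flow is an assumption):

```python
import asyncio
from mexc_trading import MexcTradingClient

async def snapshot():
    client = MexcTradingClient(symbol="ETH/USDT")
    price = await client.get_market_price()          # public endpoint
    balance = await client.fetch_account_balance()   # requires API keys
    positions = await client.fetch_open_positions()
    print(price, balance, positions)

asyncio.run(snapshot())
```
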
## Configuration

You can adjust the following parameters in `main.py`:
Key parameters can be adjusted in `main.py`:

- `INITIAL_BALANCE`: Starting balance for simulation
- `MAX_LEVERAGE`: Leverage to use for trading
- `STOP_LOSS_PERCENT`: Default stop loss percentage
- `TAKE_PROFIT_PERCENT`: Default take profit percentage
- `INITIAL_BALANCE`: Starting balance for training/demo
- `MAX_LEVERAGE`: Maximum leverage for trades
- `STOP_LOSS_PERCENT`: Stop loss percentage
- `TAKE_PROFIT_PERCENT`: Take profit percentage
- `BATCH_SIZE`: Training batch size
- `LEARNING_RATE`: Model learning rate
- `STATE_SIZE`: Size of the state representation
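A minimal sketch of how these constants might look at the top of `main.py` (`STATE_SIZE = 100` matches the `state_dim` used in the training code later in this diff; the other values are illustrative assumptions, not the project's actual defaults):

```python
# Illustrative values only -- tune for your own risk tolerance.
INITIAL_BALANCE = 1000.0     # starting balance (USDT) for training/demo
MAX_LEVERAGE = 5             # maximum leverage for trades
STOP_LOSS_PERCENT = 1.0      # stop loss, percent of entry price
TAKE_PROFIT_PERCENT = 2.0    # take profit, percent of entry price
BATCH_SIZE = 64              # training batch size
LEARNING_RATE = 1e-4         # optimizer learning rate
STATE_SIZE = 100             # size of the state representation
```
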
## Architecture
## Model Architecture

The DQN model includes:
- Input layer with technical indicators
- LSTM layers for temporal pattern recognition
- Multi-head attention mechanism
- Dueling architecture for better Q-value estimation
- Batch normalization for stable training
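A minimal PyTorch sketch of the described stack, assuming the 100-dimensional state is fed as a short sequence (layer sizes are illustrative; LayerNorm stands in for the batch normalization mentioned above, since it behaves identically at batch size 1 during live inference):

```python
import torch
import torch.nn as nn

class DuelingDQNSketch(nn.Module):
    """Illustrative LSTM + attention + dueling-head Q-network."""

    def __init__(self, state_size=100, hidden_size=256, action_size=3):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(state_size, hidden_size), nn.ReLU())
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)
        # Dueling heads: state value V(s) and per-action advantages A(s, a).
        self.value = nn.Sequential(nn.Linear(hidden_size, 128), nn.ReLU(), nn.Linear(128, 1))
        self.advantage = nn.Sequential(nn.Linear(hidden_size, 128), nn.ReLU(), nn.Linear(128, action_size))

    def forward(self, x):                      # x: (batch, seq_len, state_size)
        h = self.proj(x)
        h, _ = self.lstm(h)
        a, _ = self.attn(h, h, h)              # self-attention over timesteps
        h = self.norm(h + a)[:, -1]            # residual, then take last timestep
        v, adv = self.value(h), self.advantage(h)
        return v + adv - adv.mean(dim=1, keepdim=True)   # Q(s, a)

q = DuelingDQNSketch()(torch.randn(8, 10, 100))  # -> shape (8, 3)
```
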
## Monitoring

Training progress can be monitored using TensorBoard:
Training progress is logged to TensorBoard:

```bash
tensorboard --logdir=logs
```

This will show:
- Training rewards
- Account balance
- Win rate
- Loss metrics
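These metrics map directly onto TensorBoard scalars. A minimal sketch of the logging side (the tag names and the `log_episode` helper are assumptions, not the project's actual code):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs")  # matches `tensorboard --logdir=logs`

def log_episode(ep: int, reward: float, balance: float, win_rate: float, loss: float) -> None:
    """Log one training episode's metrics under the tags listed above."""
    writer.add_scalar("reward", reward, ep)
    writer.add_scalar("balance", balance, ep)
    writer.add_scalar("win_rate", win_rate, ep)
    writer.add_scalar("loss", loss, ep)
```
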
## Trading Strategy

The bot makes decisions based on:
- Price action
- Technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Historical patterns through LSTM
- Risk management with stop-loss and take-profit

## Safety Features

- Demo mode for safe testing
- Automatic stop-loss
- Position size limits
- Error handling for API calls
- Logging of all actions
## Directory Structure

├── main.py              # Main bot implementation
├── requirements.txt     # Project dependencies
├── .env                 # API credentials
├── models/              # Saved model checkpoints
├── runs/                # TensorBoard logs
└── trading_bot.log      # Activity logs

- `main.py`: Main entry point and trading logic
- `mexc_trading.py`: MEXC API integration for live trading using mexc-api
- `models/`: Directory for saved model weights

## Warning

Trading cryptocurrencies involves significant risk. This bot is provided for educational purposes only. Use at your own risk.
Cryptocurrency trading carries significant risks. This bot is for educational purposes and should not be used with real money without thorough testing and understanding of the risks involved.

## License

MIT
[MIT License](LICENSE)
The main changes I made:

- Fixed code block formatting by adding proper language identifiers
- Added missing closing code blocks
- Properly formatted directory structure
- Added complete sections that were cut off in the original
- Ensured consistent formatting throughout the document
- Added proper bash syntax highlighting for command examples

The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.

# Edits/improvements

- Fixes the shape mismatch by ensuring the state vector is exactly STATE_SIZE elements
- Adds robust error handling in the model's forward pass to handle mismatched inputs
- Adds a transformer encoder for more sophisticated pattern recognition
- Provides an expand_model method to increase model capacity while preserving learned weights
- Adds detailed logging about model size and shape mismatches

The model now has:

- Configurable hidden layer sizes
- Transformer layers for complex pattern recognition
- LSTM layers for temporal patterns
- Attention mechanisms for focusing on important features
- Dueling architecture for better Q-value estimation

With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need a distributed, multi-GPU architecture, which would require significant changes to the training loop.
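The parameter-count claim is easy to verify empirically for any candidate model (a sketch; works on any `nn.Module`):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

The LSTM and linear layers dominate, so the count grows roughly quadratically in the hidden size, which is why doubling `hidden_size` roughly quadruples the model.
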
@ -1,11 +1,10 @@
numpy>=1.20.0
numpy>=1.21.0
pandas>=1.3.0
matplotlib>=3.4.0
torch>=1.9.0
scikit-learn>=0.24.0
ccxt>=2.0.0
python-dotenv>=0.19.0
ccxt>=2.0.0
websockets>=10.0
tensorboard>=2.7.0
mexc-api>=1.0.0
asyncio>=3.4.3
tensorboard>=2.6.0
scikit-learn
mplfinance
@ -1,316 +0,0 @@
import argparse
import os
import sys
import asyncio
import torch
from enhanced_training import enhanced_train_agent
from exchange_simulator import ExchangeSimulator

# Fix for Windows asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')

    parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'],
                        help='Mode to run the trading bot in')

    parser.add_argument('--episodes', type=int, default=100,
                        help='Number of episodes to train for')

    parser.add_argument('--start-episode', type=int, default=0,
                        help='Episode to start from for continuous training')

    parser.add_argument('--device', type=str, default='auto',
                        help='Device to train on (auto, cuda, cpu)')

    parser.add_argument('--timeframes', type=str, default='1m,15m,1h',
                        help='Comma-separated list of timeframes to use')

    parser.add_argument('--refresh-data', action='store_true',
                        help='Refresh data before training')

    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')

    args = parser.parse_args()

    # Set device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print(f"Using device: {device}")

    # Parse timeframes
    timeframes = args.timeframes.split(',')
    print(f"Using timeframes: {timeframes}")

    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Run in specified mode
    if args.mode == 'train':
        # Train from scratch
        print(f"Training for {args.episodes} episodes...")
        enhanced_train_agent(
            exchange=exchange,
            num_episodes=args.episodes,
            continuous=False,
            start_episode=0,
            verbose=args.verbose
        )

    elif args.mode == 'continuous':
        # Continue training from checkpoint
        print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...")
        enhanced_train_agent(
            exchange=exchange,
            num_episodes=args.episodes,
            continuous=True,
            start_episode=args.start_episode,
            verbose=args.verbose
        )

    elif args.mode == 'evaluate':
        # Evaluate the model
        print("Evaluating model...")
        evaluate_model(exchange, device)

    elif args.mode == 'live' or args.mode == 'demo':
        # Run in live or demo mode
        is_demo = args.mode == 'demo'
        print(f"Running in {'demo' if is_demo else 'live'} mode...")
        run_live(exchange, device, is_demo=is_demo)

    print("Done!")

def evaluate_model(exchange, device):
    """
    Evaluate the trained model

    Args:
        exchange: Exchange simulator
        device: Device to run on
    """
    from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
    import torch
    import numpy as np

    # Load the best model
    model_path = 'models/enhanced_trading_agent_best_pnl.pt'
    if not os.path.exists(model_path):
        model_path = 'models/enhanced_trading_agent_latest.pt'

    if not os.path.exists(model_path):
        print("No model found to evaluate!")
        return

    print(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location=device)

    # Initialize models
    state_dim = 100
    action_dim = 3
    timeframes = ['1m', '15m', '1h']

    price_model = EnhancedPricePredictionModel(
        input_dim=2,
        hidden_dim=256,
        num_layers=3,
        output_dim=5,
        num_timeframes=len(timeframes)
    ).to(device)

    dqn_model = EnhancedDQN(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=512
    ).to(device)

    # Load model weights
    price_model.load_state_dict(checkpoint['price_model_state_dict'])
    dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])

    # Set models to evaluation mode
    price_model.eval()
    dqn_model.eval()

    # Run evaluation
    num_steps = 1000
    total_reward = 0
    trades = []

    # Initialize state
    from enhanced_training import initialize_state, step_environment
    state = initialize_state(exchange, timeframes)

    for step in range(num_steps):
        # Select action
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            q_values, _, _ = dqn_model(state_tensor)
            action = q_values.argmax().item()

        # Execute action
        next_state, reward, done, trade_info = step_environment(
            exchange, state, action, price_model, timeframes, device
        )

        # Update state and accumulate reward
        state = next_state
        total_reward += reward

        # Track trade
        if trade_info is not None:
            trades.append(trade_info)
            print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}")

        # Update exchange (simulate time passing)
        if step % 10 == 0:
            exchange.update()

        if done:
            break

    # Calculate metrics
    avg_reward = total_reward / num_steps
    total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
    wins = sum(1 for trade in trades if trade['pnl'] > 0)
    losses = sum(1 for trade in trades if trade['pnl'] < 0)
    win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0

    print("\nEvaluation Results:")
    print(f"Average Reward: {avg_reward:.2f}")
    print(f"Total PnL: ${total_pnl:.2f}")
    print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})")

def run_live(exchange, device, is_demo=True):
    """
    Run the trading bot in live or demo mode

    Args:
        exchange: Exchange simulator or real exchange
        device: Device to run on
        is_demo: Whether to run in demo mode (no real trades)
    """
    from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
    import torch
    import time

    # Load the best model
    model_path = 'models/enhanced_trading_agent_best_pnl.pt'
    if not os.path.exists(model_path):
        model_path = 'models/enhanced_trading_agent_latest.pt'

    if not os.path.exists(model_path):
        print("No model found to run in live mode!")
        return

    print(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location=device)

    # Initialize models
    state_dim = 100
    action_dim = 3
    timeframes = ['1m', '15m', '1h']

    price_model = EnhancedPricePredictionModel(
        input_dim=2,
        hidden_dim=256,
        num_layers=3,
        output_dim=5,
        num_timeframes=len(timeframes)
    ).to(device)

    dqn_model = EnhancedDQN(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=512
    ).to(device)

    # Load model weights
    price_model.load_state_dict(checkpoint['price_model_state_dict'])
    dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])

    # Set models to evaluation mode
    price_model.eval()
    dqn_model.eval()

    # Run live trading
    print(f"Running in {'demo' if is_demo else 'live'} mode...")
    print("Press Ctrl+C to stop")

    # Initialize state
    from enhanced_training import initialize_state, step_environment
    state = initialize_state(exchange, timeframes)

    try:
        while True:
            # Select action
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                q_values, _, market_regime = dqn_model(state_tensor)
                action = q_values.argmax().item()

            # Get market regime prediction
            regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0]
            regime_names = ['Trending', 'Ranging', 'Volatile']
            predicted_regime = regime_names[regime_probs.argmax()]

            # Get current price
            ticker = exchange.get_ticker()
            current_price = ticker['last']

            # Print state
            print(f"\nCurrent price: ${current_price:.2f}")
            print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)")

            # Execute action
            next_state, reward, _, trade_info = step_environment(
                exchange, state, action, price_model, timeframes, device
            )

            # Print action
            action_names = ['Hold', 'Buy', 'Sell']
            print(f"Action: {action_names[action]}")

            if trade_info is not None:
                print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}")

                # Execute real trade if not in demo mode
                if not is_demo:
                    if trade_info['action'] == 'buy':
                        order = exchange.create_order(
                            symbol="BTC/USDT",
                            type="market",
                            side="buy",
                            amount=trade_info['size'] / current_price
                        )
                        print(f"Executed buy order: {order}")
                    else:  # sell
                        order = exchange.create_order(
                            symbol="BTC/USDT",
                            type="market",
                            side="sell",
                            amount=trade_info['size'] / current_price
                        )
                        print(f"Executed sell order: {order}")

            # Update state
            state = next_state

            # Update exchange (simulate time passing)
            exchange.update()

            # Wait for next candle
            time.sleep(5)  # In a real implementation, this would wait for the next candle
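            # Sketch of a candle-aligned wait (an assumption, not part of the
            # original file): sleep until the next interval boundary instead
            # of a fixed 5 seconds.
            #
            #     def sleep_until_next_candle(interval_seconds=60):
            #         time.sleep(interval_seconds - time.time() % interval_seconds)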

    except KeyboardInterrupt:
        print("\nStopping live trading")

if __name__ == "__main__":
    main()
@ -1,18 +0,0 @@
import asyncio
import logging
from enhanced_training import main

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,185 +0,0 @@
import os
import sys
import json
import logging
import time
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

logger = logging.getLogger('cache_test')

# Import our cache implementation
from data_cache import ohlcv_cache

def generate_sample_data(num_candles=100):
    """Generate sample OHLCV data for testing"""
    data = []
    base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000)  # Start from num_candles minutes ago

    for i in range(num_candles):
        timestamp = base_timestamp + (i * 60 * 1000)  # Add i minutes

        # Generate some random-ish but realistic looking price data
        base_price = 1900.0 + (i * 0.5)  # Slight uptrend
        open_price = base_price - 0.5 + (i % 3)
        close_price = base_price + 0.3 + ((i + 1) % 4)
        high_price = max(open_price, close_price) + 1.0 + (i % 2)
        low_price = min(open_price, close_price) - 0.8 - (i % 2)
        volume = 10.0 + (i % 10) * 2.0

        data.append({
            'timestamp': timestamp,
            'open': open_price,
            'high': high_price,
            'low': low_price,
            'close': close_price,
            'volume': volume
        })

    return data

def test_cache_save_load():
    """Test saving and loading data from cache"""
    logger.info("Testing cache save and load...")

    # Generate sample data
    data = generate_sample_data(100)
    logger.info(f"Generated {len(data)} sample candles")

    # Save to cache
    symbol = "ETH/USDT"
    timeframe = "1m"
    success = ohlcv_cache.save(data, symbol, timeframe)
    logger.info(f"Saved to cache: {success}")

    # Load from cache
    cached_data = ohlcv_cache.load(symbol, timeframe)
    logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")

    # Verify data integrity
    if cached_data:
        first_original = data[0]
        first_cached = cached_data[0]
        logger.info(f"First original candle: {first_original}")
        logger.info(f"First cached candle: {first_cached}")

        last_original = data[-1]
        last_cached = cached_data[-1]
        logger.info(f"Last original candle: {last_original}")
        logger.info(f"Last cached candle: {last_cached}")

    return success and cached_data and len(cached_data) == len(data)

def test_cache_append():
    """Test appending a new candle to cached data"""
    logger.info("Testing cache append...")

    # Generate sample data
    data = generate_sample_data(100)

    # Save to cache
    symbol = "ETH/USDT"
    timeframe = "5m"
    success = ohlcv_cache.save(data, symbol, timeframe)
    logger.info(f"Saved to cache: {success}")

    # Generate a new candle
    last_timestamp = data[-1]['timestamp']
    new_timestamp = last_timestamp + (5 * 60 * 1000)  # 5 minutes later
    new_candle = {
        'timestamp': new_timestamp,
        'open': 1950.0,
        'high': 1955.0,
        'low': 1948.0,
        'close': 1952.0,
        'volume': 15.0
    }

    # Append to cache
    success = ohlcv_cache.append(new_candle, symbol, timeframe)
    logger.info(f"Appended to cache: {success}")

    # Load from cache
    cached_data = ohlcv_cache.load(symbol, timeframe)
    logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")

    # Verify the new candle was appended
    if cached_data:
        last_cached = cached_data[-1]
        logger.info(f"New candle: {new_candle}")
        logger.info(f"Last cached candle: {last_cached}")

    return success and cached_data and len(cached_data) == len(data) + 1

def test_cache_dataframe():
    """Test converting cached data to a pandas DataFrame"""
    logger.info("Testing cache to DataFrame conversion...")

    # Generate sample data
    data = generate_sample_data(100)

    # Save to cache
    symbol = "ETH/USDT"
    timeframe = "15m"
    success = ohlcv_cache.save(data, symbol, timeframe)
    logger.info(f"Saved to cache: {success}")

    # Convert to DataFrame
    df = ohlcv_cache.to_dataframe(symbol, timeframe)
    logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows")

    # Display DataFrame info
    if df is not None:
        logger.info(f"DataFrame columns: {df.columns.tolist()}")
        logger.info(f"DataFrame index: {df.index.name}")
        logger.info(f"First row: {df.iloc[0].to_dict()}")
        logger.info(f"Last row: {df.iloc[-1].to_dict()}")

    return success and df is not None and len(df) == len(data)

def main():
    """Run all tests"""
    logger.info("Starting cache tests...")

    # Run tests
    save_load_success = test_cache_save_load()
    append_success = test_cache_append()
    dataframe_success = test_cache_dataframe()

    # Print results
    logger.info("Test results:")
    logger.info(f"  Save/Load: {'PASS' if save_load_success else 'FAIL'}")
    logger.info(f"  Append: {'PASS' if append_success else 'FAIL'}")
    logger.info(f"  DataFrame: {'PASS' if dataframe_success else 'FAIL'}")

    # Check cache directory contents
    cache_dir = ohlcv_cache.cache_dir
    logger.info(f"Cache directory: {cache_dir}")
    if os.path.exists(cache_dir):
        files = os.listdir(cache_dir)
        logger.info(f"Cache files: {files}")

        # Print file sizes
        for file in files:
            file_path = os.path.join(cache_dir, file)
            size_kb = os.path.getsize(file_path) / 1024
            logger.info(f"  {file}: {size_kb:.2f} KB")

            # Print cache metadata from each file
            with open(file_path, 'r') as f:
                data = json.load(f)
                logger.info(f"  Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated')).strftime('%Y-%m-%d %H:%M:%S')}")
                logger.info(f"  Candles: {len(data.get('data', []))}")

    return save_load_success and append_success and dataframe_success

if __name__ == "__main__":
    main()
@ -1,23 +0,0 @@
from mexc_api.spot import Spot

def test_mexc_api():
    try:
        # Initialize client with empty API keys for public endpoints
        client = Spot("", "")

        # Test server time endpoint
        server_time = client.market.server_time()
        print(f"Server time: {server_time}")

        # Test ticker price endpoint
        ticker = client.market.ticker_price("ETHUSDT")
        print(f"ETH/USDT price: {ticker}")

        print("MEXC API is working correctly!")
        return True
    except Exception as e:
        print(f"Error testing MEXC API: {e}")
        return False

if __name__ == "__main__":
    test_mexc_api()
@ -1,34 +0,0 @@
import asyncio
import os
from dotenv import load_dotenv
from mexc_trading import MexcTradingClient

# Load environment variables
load_dotenv()

async def test_trading_client():
    """Test the MexcTradingClient functionality"""
    print("Initializing MexcTradingClient...")
    client = MexcTradingClient(symbol="ETH/USDT")

    # Test getting market price
    print("Testing get_market_price...")
    price = await client.get_market_price()
    print(f"Current ETH/USDT price: {price}")

    # If API keys are provided, test account balance
    if os.getenv('MEXC_API_KEY') and os.getenv('MEXC_SECRET_KEY'):
        print("Testing fetch_account_balance...")
        balance = await client.fetch_account_balance()
        print(f"Account balance: {balance} USDT")

        print("Testing fetch_open_positions...")
        positions = await client.fetch_open_positions()
        print(f"Open positions: {positions}")
    else:
        print("No API keys provided. Skipping private endpoint tests.")

    print("MexcTradingClient test completed!")

if __name__ == "__main__":
    asyncio.run(test_trading_client())
File diff suppressed because it is too large
Binary file not shown.
Before Width: | Height: | Size: 307 KiB After Width: | Height: | Size: 170 KiB |
Binary file not shown.
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 86 KiB |
@ -1,56 +0,0 @@
import os
import argparse
import subprocess
import webbrowser
import time
from pathlib import Path

def main():
    parser = argparse.ArgumentParser(description='Visualize TensorBoard logs')
    parser.add_argument('--logdir', type=str, default='./logs', help='Directory containing TensorBoard logs')
    parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
    args = parser.parse_args()

    log_dir = Path(args.logdir)

    if not log_dir.exists():
        print(f"Log directory {log_dir} does not exist. Creating it...")
        log_dir.mkdir(parents=True, exist_ok=True)

    # Check if TensorBoard is installed
    try:
        subprocess.run(['tensorboard', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("TensorBoard not found. Installing...")
        subprocess.run(['pip', 'install', 'tensorboard'], check=True)

    # Start TensorBoard server
    print(f"Starting TensorBoard server on port {args.port}...")
    tensorboard_process = subprocess.Popen(
        ['tensorboard', '--logdir', str(log_dir), '--port', str(args.port)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Wait for TensorBoard to start
    time.sleep(3)

    # Open browser
    url = f"http://localhost:{args.port}"
    print(f"Opening TensorBoard in browser: {url}")
    webbrowser.open(url)

    print("TensorBoard is running. Press Ctrl+C to stop.")

    try:
        # Keep the script running until interrupted
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard server...")
        tensorboard_process.terminate()
        tensorboard_process.wait()
        print("TensorBoard server stopped.")

if __name__ == "__main__":
    main()