import asyncio
import torch
import torch.nn as nn
import torch.optim as optim
from data.live_data import LiveDataManager
from model.transformer import Transformer
from training.train import train
from data.data_utils import preprocess_data  # Import preprocess_data
import ccxt.async_support as ccxt
import time
import os
import numpy as np
import matplotlib.pyplot as plt
from model.trading_model import TradingModel
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
import argparse


async def main_training():
    """Train the Transformer model on live BTC/USDT data.

    Builds the model and an AdamW optimizer, selects CUDA when available and
    delegates the training loop to ``training.train.train``.  The live data
    manager is always closed on exit, including after a manual stop.
    """
    symbol = 'BTC/USDT'
    data_manager = LiveDataManager(symbol)

    # Model parameters.  NOTE(review): these dimensions give a model far
    # smaller than ~1B parameters; scale d_model / num_layers / d_ff up if
    # that size is the goal.
    input_dim = 6 + len([5, 10, 20, 60, 120, 200])  # OHLCV + EMAs
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    dropout = 0.1

    model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # Define loss functions
    criterion_candle = nn.MSELoss()
    criterion_volume = nn.MSELoss()  # Consider a different loss for volume if needed
    criterion_ticks = nn.MSELoss()

    # Check for CUDA availability and set device
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("Using CUDA")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    try:
        await train(model, data_manager, optimizer, criterion_candle, criterion_volume,
                    criterion_ticks, num_epochs=10, device=device)
    except KeyboardInterrupt:
        print("Training stopped manually.")
    except asyncio.CancelledError:
        # Since Python 3.11, asyncio.run() turns Ctrl+C into a cancellation of
        # the main task rather than a KeyboardInterrupt raised in here.
        # Report it, then let the cancellation propagate as asyncio expects.
        print("Training stopped manually.")
        raise
    finally:
        await data_manager.close()


# -------------------------------------
# Main Asynchronous Function for Training & Charting
# -------------------------------------
async def main_backtest():
    """Fetch multi-timeframe history, train the RL agent on it, and report.

    Downloads candles for every timeframe in ``timeframes`` from MEXC, caches
    them, trains ``ContinuousRLAgent`` over a ``BacktestEnvironment`` built on
    that data, then runs one greedy (epsilon=0) episode, prints trade
    statistics and plots the trades on the base ("1m") timeframe.
    """
    symbol = 'BTC/USDT'
    base_tf = "1m"
    # Define timeframes: we'll use 5 different ones.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    now = int(time.time() * 1000)
    # Every timeframe covers the same wall-clock window: 1500 base (1m)
    # candles.  NOTE(review): for "1h"/"1d" that is only ~25 / ~1-2 candles,
    # possibly too few for get_features_for_tf(period=10) — confirm.
    period_ms = 1500 * 60 * 1000
    since = now - period_ms
    end_time = now

    # Initialize exchange using MEXC (or your preferred exchange).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })

    try:
        candles_dict = {}
        for tf in timeframes:
            print(f"Fetching historical data for timeframe {tf}...")
            candles_dict[tf] = await fetch_historical_data(
                exchange, symbol, tf, since, end_time, batch_size=500)
    finally:
        # Release the exchange's HTTP session as soon as the data is in hand,
        # and also when fetching fails.  Training and plotting never use it,
        # and plt.show() below can block for a long time.
        await exchange.close()

    # fetch_historical_data swallows fetch errors, so the base timeframe can
    # come back empty.  The environment indexes it from 0, so stop here.  Do
    # not overwrite a good cache with empty data.
    if not candles_dict.get(base_tf):
        print(f"No {base_tf} candles fetched; aborting backtest.")
        return

    # Optionally, save the multi-timeframe cache.
    save_candles_cache(CACHE_FILE, candles_dict)

    # Create the backtest environment using multi-timeframe data.
    env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes)

    # Neural Network dimensions: each timeframe produces 7 features
    # (assumed to match get_features_for_tf) -> 7 * 5 timeframes = 35.
    input_dim = len(timeframes) * 7
    hidden_dim = 128
    output_dim = 3  # Actions: SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)

    # Load best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)

    # Train the agent over the historical period.
    num_epochs = 10  # Adjust as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)

    # Run a final simulation (without exploration) to record trade history.
    cumulative_reward = _run_greedy_episode(env, rl_agent)
    print("Final simulation cumulative profit:", cumulative_reward)

    # Evaluate trade performance.
    trades = env.trade_history
    num_trades, num_wins, win_rate, total_profit = _summarize_trades(trades)
    print(f"Total trades: {num_trades}, Wins: {num_wins}, Win rate: {win_rate:.2f}%, Total Profit: {total_profit:.4f}")

    # Plot chart with buy/sell markers on the base timeframe ("1m").
    plot_trade_history(candles_dict[base_tf], trades)


def _run_greedy_episode(env, rl_agent):
    """Replay ``env`` from the start with epsilon=0 and return the summed reward.

    The environment's trade history is cleared first, so afterwards it holds
    only this episode's trades.
    """
    state = env.reset(clear_trade_history=True)
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        # step() returns (current_state, reward, next_state, done); continue
        # from next_state (None only when done is already True).
        _, reward, state, done = env.step(action)
        cumulative_reward += reward
    return cumulative_reward


def _summarize_trades(trades):
    """Return ``(num_trades, num_wins, win_rate_percent, total_profit)`` for closed trades."""
    num_trades = len(trades)
    num_wins = sum(1 for trade in trades if trade["pnl"] > 0)
    win_rate = (num_wins / num_trades * 100) if num_trades > 0 else 0.0
    total_profit = sum(trade["pnl"] for trade in trades)
    return num_trades, num_wins, win_rate, total_profit


# -------------------------------------
# Historical Data Fetching Function (for a given timeframe)
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """Fetch OHLCV candles for one timeframe between ``since`` and ``end_time``.

    Pages through ``exchange.fetch_ohlcv`` with ``limit=batch_size``, starting
    each request 1 ms after the last timestamp received.  This is best
    effort: a fetch error is printed and ends pagination, and the candles
    collected so far are returned.

    Args:
        exchange: ccxt async exchange instance.
        symbol: market symbol, e.g. ``'BTC/USDT'``.
        timeframe: ccxt timeframe string, e.g. ``'1m'``.
        since: start of the window, epoch milliseconds.
        end_time: end of the window, epoch milliseconds (inclusive).
        batch_size: ``limit`` passed to each ``fetch_ohlcv`` call.

    Returns:
        List of dicts with keys ``timestamp``, ``open``, ``high``, ``low``,
        ``close`` and ``volume``, in exchange order.  Candles timestamped
        after ``end_time`` are dropped.
    """
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:  # best effort: keep whatever was fetched so far
            print(f"Error fetching historical data for {timeframe}:", e)
            break
        if not batch:
            break

        last_timestamp = batch[-1][0]
        if last_timestamp < since_ms:
            # Nothing at or after since_ms came back (e.g. the exchange
            # re-returned the still-forming candle).  Requesting again would
            # yield the same page forever, so stop instead.
            break

        candles.extend(
            {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5],
            }
            for c in batch
            if c[0] <= end_time
        )

        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1

    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles


# -------------------------------------
# Backtest Environment with Multi-Timeframe State
# -------------------------------------
class BacktestEnvironment:
    """Replay multi-timeframe candles as a long-only, one-position trading env.

    The episode walks the base-timeframe candle list one candle per ``step``.
    A BUY while flat opens a position at the next candle's open.  A SELL
    while in a position closes it at the next candle's open, and the reward
    is the raw price difference.  No fees or position sizing are modeled.  A
    position still open when the data runs out is not added to
    ``trade_history``.
    """

    # Action mapping used by step().
    SELL, HOLD, BUY = 0, 1, 2

    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
        self.candles_dict = candles_dict  # dict of timeframe: candles_list
        self.base_tf = base_tf
        # Fall back to the base timeframe alone when no list is given.
        self.timeframes = [base_tf] if timeframes is None else timeframes
        self.trade_history = []  # record of closed trades
        self.current_index = 0  # index on base_tf candles
        self.position = None  # active position record, None when flat

    def reset(self, clear_trade_history=True):
        """Rewind to the first base candle, close any position without recording it, and return the first state."""
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the concatenated per-timeframe features at base candle ``index``.

        For each timeframe, get_aligned_candle_with_index locates the candle
        aligned to the base candle's timestamp (intended: the closest one with
        timestamp <= base_ts).  get_features_for_tf(period=10) then supplies
        that timeframe's features.  The result is a float32 numpy vector.
        """
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        state_features = []
        for tf in self.timeframes:
            candles_list = self.candles_dict[tf]
            aligned_index, _ = get_aligned_candle_with_index(candles_list, base_ts)
            state_features.extend(get_features_for_tf(candles_list, aligned_index, period=10))
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """
        Simulate a trading step based on the base timeframe.
        - If not in a position and action is BUY (2), record entry at next candle's open.
        - If in a position and action is SELL (0), record exit at next candle's open, computing PnL.
        Any other action (including HOLD, 1) just advances one candle.
        Returns: (current_state, reward, next_state, done)
        next_state is None only when called on the last base candle.
        """
        base_candles = self.candles_dict[self.base_tf]
        if self.current_index >= len(base_candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base_candles[next_index]

        reward = 0.0
        if self.position is None:
            if action == self.BUY:
                # The fill happens at the NEXT candle's open, so that candle
                # is the entry index.  This matches exit_index below and puts
                # the chart marker on the candle whose open is the entry price.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == self.SELL:
            # SELL signal: close position at next candle's open.
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = self.current_index >= len(base_candles) - 1
        return current_state, reward, next_state, done


# -------------------------------------
# Chart Plotting: Trade History & PnL
# -------------------------------------
def plot_trade_history(candles, trade_history):
    """Plot close prices with every closed trade overlaid, then show the figure.

    The x-axis is the index into ``candles``.  Each trade dict (keys
    ``entry_index``, ``entry_price``, ``exit_index``, ``exit_price`` and
    ``pnl``) gets a green "IN" marker, a red "OUT" marker annotated with its
    PnL, and a dotted line colored by profitability.  ``plt.show()`` may block
    until the window is closed.
    """
    close_prices = [candle["close"] for candle in candles]
    x = list(range(len(close_prices)))

    plt.figure(figsize=(12, 6))
    plt.plot(x, close_prices, label="Close Price", color="black", linewidth=1)

    for i, trade in enumerate(trade_history):
        in_idx, in_price = trade["entry_index"], trade["entry_price"]
        out_idx, out_price = trade["exit_index"], trade["exit_price"]
        pnl = trade["pnl"]
        # Label only the first trade's markers so BUY/SELL appear once in the legend.
        first = i == 0

        # Plot BUY marker ("IN")
        plt.plot(in_idx, in_price, marker="^", color="green", markersize=10,
                 label="BUY (IN)" if first else "_nolegend_")
        plt.text(in_idx, in_price, " IN", color="green", fontsize=8, verticalalignment="bottom")

        # Plot SELL marker ("OUT")
        plt.plot(out_idx, out_price, marker="v", color="red", markersize=10,
                 label="SELL (OUT)" if first else "_nolegend_")
        plt.text(out_idx, out_price, " OUT", color="red", fontsize=8, verticalalignment="top")

        # Annotate the PnL near the SELL marker.
        plt.text(out_idx, out_price, f" {pnl:+.2f}", color="blue", fontsize=8, verticalalignment="bottom")

        # Choose line color based on profitability.
        if pnl > 0:
            line_color = "green"
        elif pnl < 0:
            line_color = "red"
        else:
            line_color = "gray"

        # Draw a dotted line between the buy and sell points.
        plt.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color=line_color)

    plt.title("Trade History with PnL")
    plt.xlabel("Base Candle Index (1m)")
    plt.ylabel("Price")
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trading Bot Modes')
    parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'],
                        help='Choose mode: train or backtest')
    args = parser.parse_args()

    if args.mode == 'train':
        asyncio.run(main_training())
    elif args.mode == 'backtest':
        asyncio.run(main_backtest())