#!/usr/bin/env python3
import asyncio
import bisect
import json
import os
import sys
import time
from collections import deque
from datetime import datetime

import ccxt.async_support as ccxt
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dotenv import load_dotenv

# On Windows, switch to the selector event loop before asyncio.run() creates a
# loop. NOTE(review): presumably needed by ccxt's aiohttp-based async client,
# which has known issues with the default Proactor loop — confirm.
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

CACHE_FILE = "candles_cache.json"

# Number of checkpoints kept in each of LAST_DIR (most recent) and BEST_DIR
# (highest reward).
MAX_CHECKPOINTS = 10


# -------------------------------------
# Utility functions for caching candles to file
# -------------------------------------
def load_candles_cache(filename):
    """Return the candle dict stored as JSON in *filename*.

    Returns ``{}`` when the file does not exist or cannot be read/parsed
    (the error is printed).
    """
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
        print("Error reading cache file:", e)
        return {}
    print(f"Loaded cached data from {filename}.")
    return data


def save_candles_cache(filename, candles_dict):
    """Write *candles_dict* to *filename* as JSON (best effort; errors are printed).

    The data is written to a temporary file first and then moved into place,
    so a failed write never leaves a truncated cache behind.
    """
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(candles_dict, f)
        os.replace(tmp_path, filename)
    except (OSError, TypeError, ValueError) as e:
        print("Error saving cache file:", e)


# -------------------------------------
# Checkpoint management
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest files (by modification time) in *directory* until at
    most *max_files* remain. Subdirectories are ignored."""
    paths = [os.path.join(directory, name) for name in os.listdir(directory)]
    paths = [path for path in paths if os.path.isfile(path)]
    excess = len(paths) - max_files
    if excess <= 0:
        return
    paths.sort(key=os.path.getmtime)
    for path in paths[:excess]:
        os.remove(path)


def get_best_models(directory):
    """Return ``(reward, filename)`` pairs for the best-model checkpoints in *directory*.

    Best checkpoints are named ``best_<reward>_epoch_<n>_<timestamp>.pt`` (see
    save_checkpoint); other files are ignored.
    """
    best_files = []
    for name in os.listdir(directory):
        parts = name.split("_")
        if len(parts) < 2 or parts[0] != "best":
            continue
        try:
            best_files.append((float(parts[1]), name))
        except ValueError:
            continue
    return best_files


def _prune_best_models(best_dir, max_files=MAX_CHECKPOINTS):
    """Delete the lowest-reward best checkpoints so at most *max_files* remain."""
    ranked = sorted(get_best_models(best_dir), key=lambda item: item[0], reverse=True)
    for _, name in ranked[max_files:]:
        os.remove(os.path.join(best_dir, name))


def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save *model*'s weights to *last_dir* and, if *reward* ranks in the top
    MAX_CHECKPOINTS, also to *best_dir*.

    *last_dir* keeps the most recent checkpoints (by mtime). *best_dir* keeps
    the highest-reward ones. The reward is encoded in the best filename so
    checkpoints can be ranked without loading them.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    payload = {
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict(),
    }

    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(payload, last_path)
    maintain_checkpoint_directory(last_dir, max_files=MAX_CHECKPOINTS)

    best_models = get_best_models(best_dir)
    qualifies = (
        len(best_models) < MAX_CHECKPOINTS
        or reward > min(r for r, _ in best_models)
    )
    if qualifies:
        best_path = os.path.join(best_dir, f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt")
        # Save first, then prune by reward: a failed save never costs an
        # existing best model, and pruning never drops a better model for a
        # worse one (mtime-based pruning could).
        torch.save(payload, best_path)
        _prune_best_models(best_dir, max_files=MAX_CHECKPOINTS)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the highest-reward checkpoint from *best_dir* into *model*.

    Returns:
        The loaded checkpoint dict, or ``None`` when there is no checkpoint or
        its weights do not match *model*'s parameter names/shapes. A mismatch
        happens, for example, after the feature/timeframe configuration changed.
        On a mismatch *model* is left untouched.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda item: item[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path, map_location="cpu")

    # Validate shapes up front: load_state_dict() on a mismatch raises only
    # after copying the compatible tensors, leaving a half-loaded model.
    state_dict = checkpoint.get("model_state_dict") if isinstance(checkpoint, dict) else None
    expected = model.state_dict()
    compatible = (
        state_dict is not None
        and state_dict.keys() == expected.keys()
        and all(state_dict[key].shape == expected[key].shape for key in expected)
    )
    if not compatible:
        print(f"Checkpoint {best_file} does not match the current model; starting from scratch.")
        return None
    model.load_state_dict(state_dict)
    return checkpoint


# -------------------------------------
# Technical Indicator Helper Functions
# -------------------------------------
def _sma_of_field(candles_list, index, period, field):
    """Mean of ``candle[field]`` over the *period* candles ending at *index*
    (fewer near the start of the list); 0.0 for an empty window."""
    window = candles_list[max(0, index - period + 1):index + 1]
    if not window:
        return 0.0
    return sum(candle[field] for candle in window) / len(window)


def compute_sma(candles_list, index, period=10):
    """Simple moving average of the close price over the window ending at *index*."""
    return _sma_of_field(candles_list, index, period, "close")


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of volume over the window ending at *index*."""
    return _sma_of_field(candles_list, index, period, "volume")


def get_aligned_candle_with_index(candles_list, target_ts, timestamps=None):
    """Find the candle whose timestamp is the largest one that is <= target_ts.

    *candles_list* must be sorted by ``timestamp``. If every candle is newer
    than *target_ts*, index 0 is returned. Note that this candle lies in the
    future relative to *target_ts*.

    Args:
        candles_list: candle dicts sorted by timestamp.
        target_ts: timestamp to align to (same unit as the candles).
        timestamps: optional precomputed ``[c["timestamp"] for c in candles_list]``.
            Passing it makes the lookup O(log n) instead of O(n).

    Returns:
        ``(index, candle)``.

    Raises:
        ValueError: if *candles_list* is empty.
    """
    if not candles_list:
        raise ValueError("candles_list is empty")
    if timestamps is None:
        timestamps = [candle["timestamp"] for candle in candles_list]
    best_idx = max(bisect.bisect_right(timestamps, target_ts) - 1, 0)
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# -------------------------------------
# Neural Network Architecture Definition
# -------------------------------------
class TradingModel(nn.Module):
    """MLP mapping a state vector to one Q-value per action (two ReLU hidden layers)."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)


# -------------------------------------
# Replay Buffer for Experience Storage
# -------------------------------------
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions; the oldest entries are evicted first."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        """Append one ``(state, action, reward, next_state, done)`` tuple."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return *batch_size* distinct experiences chosen uniformly at random.

        Raises ValueError (from NumPy) if batch_size exceeds the buffer length.
        """
        indices = np.random.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)


# -------------------------------------
# RL Agent with Q-Learning and Epsilon-Greedy Exploration
# -------------------------------------
class ContinuousRLAgent:
    """Q-learning agent with epsilon-greedy exploration over 3 discrete actions.

    The same network is used for both the current and the bootstrap
    (next-state) Q-values; there is no separate target network.
    """

    def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
        self.model = model
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma

    def act(self, state, epsilon=0.1):
        """Return a random action in {0, 1, 2} with probability *epsilon*,
        otherwise the action with the highest predicted Q-value."""
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = self.model(state_tensor)
        return torch.argmax(output, dim=1).item()

    def train_step(self):
        """Run one gradient step on a random minibatch using the TD(0) target.

        Returns:
            The minibatch loss as a float, or ``None`` if the buffer holds fewer
            than ``batch_size`` experiences (no update is made).
        """
        if len(self.replay_buffer) < self.batch_size:
            return None
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
        rewards_tensor = torch.from_numpy(np.array(rewards, dtype=np.float32)).unsqueeze(1)
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)

        # Q(s, a) for the actions actually taken.
        current_Q = self.model(states_tensor).gather(1, actions_tensor.unsqueeze(1))
        with torch.no_grad():
            max_next_Q = self.model(next_states_tensor).max(1)[0].unsqueeze(1)
        # Terminal transitions do not bootstrap from the next state.
        target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)

        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


# -------------------------------------
# Historical Data Fetching Function (for a given timeframe)
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """Fetch OHLCV candles for *symbol*/*timeframe* from *since* up to *end_time* (ms).

    Pages through ``exchange.fetch_ohlcv`` with up to *batch_size* candles per
    request. On a fetch error, the error is printed and the candles collected
    so far are returned.

    Returns:
        A list of dicts with keys ``timestamp``, ``open``, ``high``, ``low``,
        ``close``, ``volume``, in strictly increasing timestamp order.
    """
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:  # best effort: any network/exchange error ends paging
            print(f"Error fetching historical data for {timeframe}:", e)
            break
        if not batch:
            break
        # Keep only candles newer than what we already have. An exchange that
        # does not advance past `since` would otherwise cause duplicates or an
        # endless loop.
        last_seen = candles[-1]["timestamp"] if candles else None
        fresh = [c for c in batch if last_seen is None or c[0] > last_seen]
        if not fresh:
            break
        candles.extend(
            {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5],
            }
            for c in fresh
        )
        last_timestamp = fresh[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles
# -------------------------------------
# Backtest Environment with Multi-Timeframe State
# -------------------------------------
_TIMEFRAME_UNIT_MS = {
    "s": 1000,
    "m": 60 * 1000,
    "h": 60 * 60 * 1000,
    "d": 24 * 60 * 60 * 1000,
    "w": 7 * 24 * 60 * 60 * 1000,
}


def _timeframe_to_ms(timeframe):
    """Return the length of a ccxt-style timeframe string such as "15m" or "1h" in ms.

    Returns None for units without a fixed length (e.g. "1M" = one month) and
    for strings that cannot be parsed.
    """
    unit = timeframe[-1:]
    amount = timeframe[:-1]
    if unit not in _TIMEFRAME_UNIT_MS or not amount.isdigit():
        return None
    return int(amount) * _TIMEFRAME_UNIT_MS[unit]


class BacktestEnvironment:
    """Long-only, single-position backtest that steps through the base-timeframe candles.

    The state at base index i is the concatenation of FEATURES_PER_TF features
    (see get_features_for_tf) for every timeframe in ``timeframes``.
    Actions: 0 -> SELL, 1 -> HOLD, 2 -> BUY. Orders fill at the open of the
    next base candle. The reward is the realized PnL (exit price - entry price)
    when a position is closed, and 0.0 otherwise.
    """

    FEATURES_PER_TF = 7

    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
        self.candles_dict = candles_dict  # timeframe -> candle list sorted by timestamp
        self.base_tf = base_tf
        # Fall back to the base timeframe alone when none are given.
        self.timeframes = [base_tf] if timeframes is None else timeframes
        self.trade_history = []  # record of closed trades
        self.current_index = 0  # index on base_tf candles
        self.position = None  # active position record, or None when flat
        self._timestamp_cache = {}  # timeframe -> list of candle timestamps

    def _timestamps_for(self, tf):
        """Return the (cached) list of timestamps for timeframe *tf*."""
        candles = self.candles_dict[tf]
        cached = self._timestamp_cache.get(tf)
        if cached is None or len(cached) != len(candles):
            cached = [candle["timestamp"] for candle in candles]
            self._timestamp_cache[tf] = cached
        return cached

    def reset(self, clear_trade_history=True):
        """Rewind to the first base candle, close any position, and return the initial state."""
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the float32 state vector for base candle *index*.

        For each timeframe, features come from the latest candle that has
        *closed* by the time the base candle closes
        (``open_ts + tf_duration <= base_open_ts + base_duration``).
        Using a still-open higher-timeframe candle would leak its future
        high/low/close into the state (look-ahead bias). If a duration cannot
        be determined from the timeframe string, alignment falls back to the
        latest candle opened at or before the base candle. A timeframe with no
        qualifying candle contributes zeros.
        """
        from bisect import bisect_right

        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        base_ms = _timeframe_to_ms(self.base_tf)
        state_features = []
        for tf in self.timeframes:
            candles_list = self.candles_dict[tf]
            tf_ms = _timeframe_to_ms(tf)
            if base_ms is None or tf_ms is None:
                cutoff = base_ts
            else:
                cutoff = base_ts + base_ms - tf_ms
            aligned_index = bisect_right(self._timestamps_for(tf), cutoff) - 1
            if aligned_index < 0:
                state_features.extend([0.0] * self.FEATURES_PER_TF)
            else:
                state_features.extend(get_features_for_tf(candles_list, aligned_index, period=10))
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Apply *action* to the current base candle and advance by one candle.

        - Flat and action is BUY (2): open a position at the next candle's open.
        - In a position and action is SELL (0): close it at the next candle's
          open, record the trade, and return its PnL as the reward.
        - Anything else is a no-op.

        Returns:
            ``(current_state, reward, next_state, done)``. ``current_state`` is
            the state the action was taken in. ``next_state`` is None only when
            the episode had already ended before this call.
        """
        base_candles = self.candles_dict[self.base_tf]
        if self.current_index >= len(base_candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        fill_price = base_candles[next_index]["open"]
        reward = 0.0

        if self.position is None:
            if action == 2:
                # The fill happens at next_index's open. Record that index, as
                # exits do, so chart markers sit on the filled candle.
                self.position = {"entry_price": fill_price, "entry_index": next_index}
        elif action == 0:
            reward = fill_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": fill_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = self.current_index >= len(base_candles) - 1
        return current_state, reward, next_state, done


# -------------------------------------
# Chart Plotting: Trade History & PnL
# -------------------------------------
def plot_trade_history(candles, trade_history):
    """Plot close prices with BUY/SELL markers, per-trade PnL labels and entry-exit lines.

    x coordinates are indices into *candles*. Each trade dict must provide
    ``entry_index``, ``exit_index``, ``entry_price``, ``exit_price`` and ``pnl``.
    """
    close_prices = [candle["close"] for candle in candles]
    plt.figure(figsize=(12, 6))
    plt.plot(range(len(close_prices)), close_prices, label="Close Price", color="black", linewidth=1)

    for trade_no, trade in enumerate(trade_history):
        in_idx, in_price = trade["entry_index"], trade["entry_price"]
        out_idx, out_price = trade["exit_index"], trade["exit_price"]
        pnl = trade["pnl"]
        # Label only the first markers so the legend shows BUY/SELL once.
        first = trade_no == 0

        plt.plot(in_idx, in_price, marker="^", color="green", markersize=10,
                 label="BUY (IN)" if first else "_nolegend_")
        plt.text(in_idx, in_price, " IN", color="green", fontsize=8, verticalalignment="bottom")

        plt.plot(out_idx, out_price, marker="v", color="red", markersize=10,
                 label="SELL (OUT)" if first else "_nolegend_")
        plt.text(out_idx, out_price, " OUT", color="red", fontsize=8, verticalalignment="top")
        plt.text(out_idx, out_price, f" {pnl:+.2f}", color="blue", fontsize=8, verticalalignment="bottom")

        # Dotted line from entry to exit, colored by profitability.
        if pnl > 0:
            line_color = "green"
        elif pnl < 0:
            line_color = "red"
        else:
            line_color = "gray"
        plt.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color=line_color)

    plt.title("Trade History with PnL")
    plt.xlabel("Base Candle Index (1m)")
    plt.ylabel("Price")
    plt.legend()
    plt.grid(True)
    plt.show()


# -------------------------------------
# Training Loop: Backtesting Trading Episodes
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """Run *num_epochs* epsilon-greedy episodes over *env*, training after every step.

    A checkpoint is saved after each epoch, with the episode's total realized
    reward. A position still open at the end of an episode contributes nothing.
    """
    for epoch in range(1, num_epochs + 1):
        state = env.reset()  # clear trade history each epoch
        done = False
        total_reward = 0.0
        steps = 0
        while not done:
            action = rl_agent.act(state, epsilon=epsilon)
            # step() returns the state the action was taken in first. We
            # already hold it, so continue from next_state (reusing the first
            # element made the agent act on a one-step-stale state and stored
            # misaligned transitions).
            _, reward, next_state, done = env.step(action)
            if next_state is None:
                next_state = np.zeros_like(state)
            rl_agent.replay_buffer.add((state, action, reward, next_state, done))
            rl_agent.train_step()
            total_reward += reward
            steps += 1
            state = next_state
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)


# -------------------------------------
# Main Asynchronous Function for Training & Charting
# -------------------------------------
async def main_backtest():
    """Fetch multi-timeframe BTC/USDT history from MEXC, train the agent, then evaluate and plot."""
    symbol = 'BTC/USDT'
    # Define timeframes: we'll use 5 different ones.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    base_tf = "1m"
    sma_period = 10  # matches the period used by BacktestEnvironment.get_state

    now = int(time.time() * 1000)
    # Backtest window: 1500 base candles. For 1m, that is 1500 minutes.
    period_ms = 1500 * 60 * 1000
    since = now - period_ms
    end_time = now

    # Initialize exchange using MEXC (or your preferred exchange).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })

    candles_dict = {}
    try:
        for tf in timeframes:
            print(f"Fetching historical data for timeframe {tf}...")
            tf_since = since
            tf_ms = _timeframe_to_ms(tf)
            if tf != base_tf and tf_ms is not None:
                # Warm-up history: higher timeframes need candles that have
                # already closed (plus a full SMA window) at the first base candle.
                tf_since = since - (sma_period + 1) * tf_ms
            candles_dict[tf] = await fetch_historical_data(
                exchange, symbol, tf, tf_since, end_time, batch_size=500
            )
    finally:
        # The exchange is only needed for fetching. Release its session even
        # if fetching fails, and before the long training run.
        await exchange.close()

    missing = [tf for tf in timeframes if not candles_dict.get(tf)]
    if missing:
        print(f"No candles fetched for timeframe(s) {missing}; aborting backtest.")
        return

    # Optionally, save the multi-timeframe cache.
    save_candles_cache(CACHE_FILE, candles_dict)

    env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes)

    # Each timeframe contributes 7 features: 7 * 5 timeframes = 35 inputs.
    # NOTE(review): features are raw prices/volumes (no normalization), which
    # can make Q-learning unstable — consider scaling them.
    input_dim = len(timeframes) * BacktestEnvironment.FEATURES_PER_TF
    hidden_dim = 128
    output_dim = 3  # Actions: SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)

    # Load best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)

    # Train the agent over the historical period.
    num_epochs = 10  # Adjust as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)

    # Run a final greedy simulation (no exploration) to record the trade history.
    state = env.reset(clear_trade_history=True)
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        _, reward, state, done = env.step(action)
        cumulative_reward += reward
    print("Final simulation cumulative profit:", cumulative_reward)

    # Evaluate trade performance.
    trades = env.trade_history
    num_trades = len(trades)
    num_wins = sum(1 for trade in trades if trade["pnl"] > 0)
    win_rate = (num_wins / num_trades * 100) if num_trades > 0 else 0.0
    total_profit = sum(trade["pnl"] for trade in trades)
    print(f"Total trades: {num_trades}, Wins: {num_wins}, Win rate: {win_rate:.2f}%, Total Profit: {total_profit:.4f}")

    # Plot chart with buy/sell markers on the base timeframe.
    plot_trade_history(candles_dict[base_tf], trades)


if __name__ == "__main__":
    load_dotenv()
    asyncio.run(main_backtest())