#!/usr/bin/env python3
"""Transformer-based high/low price predictor with a simple backtest environment.

The module defines the model, helpers that turn cached candle dictionaries into
per-timeframe feature rows, checkpoint management under ``models/``, a
step-wise backtest environment and the supervised training loop.
"""
import sys
import asyncio

if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
from dotenv import load_dotenv

load_dotenv()

# --- Directories ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"

# --- Constants ---
NUM_TIMEFRAMES = 5  # Example: ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20  # Example: 20 technical indicators
# open, high, low, close, volume, SMA(close), SMA(volume) -- see get_features_for_tf
FEATURES_PER_CHANNEL = 7


# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is stored as the non-trainable buffer ``pe`` with shape
    ``[max_len, 1, d_model]`` and is sliced to the input's first dimension.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim div_term to avoid a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Per-channel encoders followed by a transformer and two regression heads.

    Each channel (one feature row of ``FEATURES_PER_CHANNEL`` values) has its
    own linear branch; the channel embeddings are mixed by a 2-layer
    transformer encoder, attention-pooled and fed to the ``high_pred`` and
    ``low_pred`` heads.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1)
            ) for _ in range(num_channels)
        ])
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
        # Instantiated (and therefore part of the state dict) but not applied
        # in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=False
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, timeframe_ids):
        """Predict the (high, low) pair for each sample in the batch.

        Args:
            x: tensor of shape ``[batch_size, num_channels, FEATURES_PER_CHANNEL]``.
            timeframe_ids: 1-D integer tensor of embedding indices, one per
                leading channel. It may be shorter than ``num_channels``;
                channels without an id receive no timeframe embedding.

        Returns:
            Two tensors of shape ``[batch_size]``: predicted high and low.
        """
        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
        num_channels = x.size(1)
        channel_outs = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]
        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
        stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]

        tf_embeds = self.timeframe_embed(timeframe_ids)  # [n_ids, hidden]
        missing = num_channels - tf_embeds.size(0)
        if missing > 0:
            # Zero-pad so channels beyond the supplied ids are left unchanged
            # instead of indexing past the end of the embedding table.
            padding = tf_embeds.new_zeros(missing, tf_embeds.size(1))
            tf_embeds = torch.cat([tf_embeds, padding], dim=0)
        stacked = stacked + tf_embeds.unsqueeze(1)

        # Upper-triangular boolean mask: channel i may not attend to channels > i.
        src_mask = torch.triu(
            torch.ones(num_channels, num_channels, device=x.device), diagonal=1
        ).bool()
        # nn.TransformerEncoder.forward names this argument ``mask``.
        transformer_out = self.transformer(stacked, mask=src_mask)

        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)
        # squeeze(-1) drops only the feature dimension, so a batch of one still
        # yields a 1-element vector rather than a 0-d scalar.
        return self.high_pred(aggregated).squeeze(-1), self.low_pred(aggregated).squeeze(-1)


# --- Technical Indicator Helper Functions ---
def _rolling_mean(candles_list, index, field, period):
    """Mean of ``field`` over the up-to-``period`` candles ending at ``index``."""
    start = max(0, index - period + 1)
    values = [candle[field] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def compute_sma(candles_list, index, period=10):
    """Simple moving average of ``close`` ending at ``index`` (0.0 if no candles)."""
    return _rolling_mean(candles_list, index, "close", period)


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of ``volume`` ending at ``index`` (0.0 if no candles)."""
    return _rolling_mean(candles_list, index, "volume", period)


def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` of the last candle with timestamp <= ``target_ts``.

    The scan stops at the first candle whose timestamp exceeds ``target_ts``,
    so the list is expected to be in ascending timestamp order. If no candle
    qualifies, index 0 is returned.
    """
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Build the ``FEATURES_PER_CHANNEL`` feature row for the candle at ``index``.

    Returns:
        ``[open, high, low, close, volume, sma_close, sma_volume]``.
    """
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# --- Caching and Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the JSON candle cache, returning ``{}`` if it is missing or unreadable."""
    if os.path.exists(filename):
        try:
            with open(filename, "r", encoding="utf-8") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            print("Error reading cache file:", e)
    return {}


def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` to ``filename`` as JSON; failures are reported, not raised."""
    try:
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(candles_dict, f)
    except (OSError, TypeError, ValueError) as e:
        print("Error saving cache file:", e)


def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries (by mtime) until at most ``max_files`` remain."""
    files = os.listdir(directory)
    excess = len(files) - max_files
    if excess > 0:
        full_paths = sorted(
            (os.path.join(directory, name) for name in files),
            key=os.path.getmtime,
        )
        for path in full_paths[:excess]:
            os.remove(path)


def get_best_models(directory):
    """Return ``(reward, filename)`` for every file named ``<prefix>_<reward>_...``.

    Files whose second ``_``-separated field is not a float are ignored.
    """
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            best_files.append((float(parts[1]), file))
        except (IndexError, ValueError):
            continue
    return best_files


def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a rolling "last" checkpoint and, if it qualifies, a "best" checkpoint.

    ``reward`` is treated as higher-is-better: the checkpoint enters
    ``best_dir`` when fewer than 10 are stored or when it beats the lowest
    stored reward, which is then removed.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }

    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    torch.save(checkpoint, os.path.join(last_dir, last_filename))
    maintain_checkpoint_directory(last_dir, max_files=10)

    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        torch.save(checkpoint, os.path.join(best_dir, best_filename))
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the highest-reward checkpoint from ``best_dir`` into ``model``.

    Returns:
        The loaded checkpoint dict, or ``None`` when no checkpoint is found.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host; load_state_dict copies the values onto the model's own device.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


# --- Backtest Environment ---
class BacktestEnvironment:
    """Step through cached base-timeframe candles with a single long position.

    Attributes set in ``__init__``: ``candles_dict`` (timeframe -> candle
    list), ``base_tf``, ``timeframes``, ``current_index``, ``trade_history``
    (list of closed-trade dicts) and ``position`` (open position or ``None``).
    """

    def __init__(self, candles_dict, base_tf, timeframes):
        self.candles_dict = candles_dict  # dict of timeframe: candles_list
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.current_index = 0
        self.trade_history = []
        self.position = None

    def reset(self, clear_trade_history=True):
        """Rewind to the first candle, drop any open position, return the state.

        Args:
            clear_trade_history: when False the recorded trades are kept
                (the same list object), e.g. so a chart can keep showing them.
        """
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the float32 state array for base candle ``index``.

        One feature row per timeframe (aligned by timestamp to the base
        candle) followed by ``NUM_INDICATORS`` all-zero placeholder rows.
        """
        state_features = []
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        for tf in self.timeframes:
            aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
            state_features.append(get_features_for_tf(self.candles_dict[tf], aligned_idx))
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Advance one base candle, applying ``action`` at the next candle's open.

        Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. BUY opens a position
        when flat; SELL closes an open position and yields its price
        difference as reward. Any other value changes nothing.

        Returns:
            ``(current_state, reward, next_state, done)``. When already on the
            last candle it returns ``(state, 0.0, None, True)`` without moving.
        """
        base_candles = self.candles_dict[self.base_tf]
        if self.current_index >= len(base_candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base_candles[next_index]
        reward = 0.0

        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal
                # The fill happens at the next candle's open, so the entry is
                # indexed at that candle, matching how exits are indexed.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:  # SELL signal
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward
            })
            self.position = None

        self.current_index = next_index
        done = (self.current_index >= len(base_candles) - 1)
        return current_state, reward, next_state, done

    def __len__(self):
        return len(self.candles_dict[self.base_tf])


# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args):
    """Train ``model`` to predict the next base candle's high and low.

    For each base candle the model sees the current state and is regressed
    (L1 loss, scaled by 2) onto the following candle's ``high``/``low``. A
    checkpoint is saved after every epoch, ranked by negated mean loss.

    Args:
        env: a ``BacktestEnvironment`` providing states and candles.
        model: a ``TradingModel``.
        device: torch device the model lives on.
        args: namespace with ``lr`` and ``epochs``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    base_candles = env.candles_dict[env.base_tf]
    # One id per timeframe channel; the indicator channels carry no timeframe.
    timeframe_ids = torch.arange(len(env.timeframes), device=device)

    for epoch in range(args.epochs):
        state = env.reset()
        total_loss = 0.0
        steps = 0
        model.train()
        # With fewer than two candles there is no "next candle" to learn from.
        done = len(base_candles) < 2
        while not done:
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)

            # step() returns four values and has advanced current_index to the
            # next candle, whose high/low are the regression targets.
            _, _, next_state, done = env.step(None)
            target_candle = base_candles[env.current_index]
            target_high = torch.tensor([target_candle["high"]], dtype=torch.float32, device=device)
            target_low = torch.tensor([target_candle["low"]], dtype=torch.float32, device=device)

            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            steps += 1
            state = next_state

        scheduler.step()
        avg_loss = total_loss / max(steps, 1)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")
        # save_checkpoint ranks by higher-is-better reward, so pass the
        # negated loss; passing the raw loss would keep the worst models.
        save_checkpoint(model, epoch, -avg_loss)
# --- Live Plotting Functions ---
def update_live_chart(ax, candles, trade_history):
    """Redraw the close-price line and every recorded trade on ``ax``.

    Args:
        ax: Matplotlib axes to draw on (cleared first).
        candles: candle dicts; only ``"close"`` is read.
        trade_history: dicts with ``entry_index``, ``entry_price``,
            ``exit_index`` and ``exit_price``.
    """
    ax.clear()
    close_prices = [candle["close"] for candle in candles]
    x = list(range(len(close_prices)))
    ax.plot(x, close_prices, label="Close Price", color="black", linewidth=1)

    for n, trade in enumerate(trade_history):
        in_idx, in_price = trade["entry_index"], trade["entry_price"]
        out_idx, out_price = trade["exit_index"], trade["exit_price"]
        # Only the first trade contributes legend entries; "_nolegend_" keeps
        # the remaining markers out of the legend.
        buy_label = "BUY" if n == 0 else "_nolegend_"
        sell_label = "SELL" if n == 0 else "_nolegend_"
        ax.plot(in_idx, in_price, marker="^", color="green", markersize=10, label=buy_label)
        ax.plot(out_idx, out_price, marker="v", color="red", markersize=10, label=sell_label)
        ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue")

    ax.set_title("Live Trading Chart")
    ax.set_xlabel("Candle Index")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)


def live_preview_loop(candles, env):
    """Redraw the chart from ``env.trade_history`` once per second, forever.

    NOTE(review): main() runs this in a daemon thread; interactive Matplotlib
    backends generally expect to be driven from the main thread -- confirm
    this works on the target backend.
    """
    plt.ion()
    fig, ax = plt.subplots(figsize=(12, 6))
    while True:
        # Draw from a snapshot: the trading loop appends to the list while
        # this thread is iterating.
        update_live_chart(ax, candles, list(env.trade_history))
        plt.draw()
        plt.pause(1)  # Update every second


# --- Argument Parsing ---
def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``mode``, ``epochs``, ``lr`` and ``threshold``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--threshold', type=float, default=0.005)
    return parser.parse_args(argv)


def random_action():
    """Return a uniformly random action id in {0, 1, 2}."""
    return random.randint(0, 2)


# --- Main Function ---
async def main():
    """Entry point: run training, the simulated live loop, or inference setup."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define timeframes; these must match your data and expected state dimensions.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]

    # For the Transformer model, we set number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
    model = TradingModel(NUM_TIMEFRAMES + NUM_INDICATORS, NUM_TIMEFRAMES).to(device)

    if args.mode == 'train':
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        base_tf = "1m"
        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
        train_on_historical_data(env, model, device, args)
    elif args.mode == 'live':
        load_best_checkpoint(model)
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)

        # Start the live preview in a separate daemon thread.
        preview_thread = threading.Thread(
            target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True
        )
        preview_thread.start()

        print("Starting live trading loop. (Using random actions for simulation.)")
        while True:
            _, _, _, done = env.step(random_action())
            if done:
                print("Reached end of simulated data, resetting environment.")
                # reset() starts a fresh trade_history list, so stash the
                # accumulated trades and restore them afterwards to keep them
                # on the chart.
                history = env.trade_history
                env.reset()
                env.trade_history = history
            await asyncio.sleep(1)  # Simulate one candle per second.
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Implement your inference loop here.
    else:
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())