#!/usr/bin/env python3
"""Transformer-based trading model: backtest training, live preview and inference.

Modes (``--mode``):
  * ``train``     -- train on cached historical candles using random sliding windows.
  * ``live``      -- replay cached candles with the model-based decision rule while
                     a preview chart refreshes.
  * ``inference`` -- load the best checkpoint (inference logic not implemented yet).
"""
import sys
import asyncio

if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import os
import time
import json
import argparse
import threading
import random
import math
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Plotting is only needed for the HTML chart and the live preview. Keep it
# optional so training/inference still run on hosts without matplotlib.
try:
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
except ImportError:
    plt = None
    mdates = None

# python-dotenv is optional: nothing in this script reads environment variables
# directly, so a missing package should not prevent it from starting.
try:
    from dotenv import load_dotenv
except ImportError:
    load_dotenv = None

if load_dotenv is not None:
    load_dotenv()

# --- Directories ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"

# --- Constants ---
NUM_TIMEFRAMES = 5  # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20  # number of indicator channels (zero-filled by get_state for now)
# Each channel input has 7 features.
FEATURES_PER_CHANNEL = 7
# One extra channel carries the open-order information.
ORDER_CHANNELS = 1


# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal positional encoding, then apply dropout.

    Input is sequence-first: ``x`` is ``[seq_len, batch, d_model]`` and
    ``seq_len`` must not exceed ``max_len``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Predict the next candle's high and low from a per-channel feature matrix.

    Each input channel (timeframes, order state, indicators) has its own MLP
    branch. The branch outputs get a learned per-channel embedding, pass through
    a Transformer encoder, and are attention-pooled into two regression heads.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        # num_timeframes is currently unused; kept for interface compatibility.
        # One branch per channel.
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )
            for _ in range(num_channels)
        ])
        # Learned embedding indexed by channel id 0..num_channels-1.
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE: pos_encoder is not applied in forward(); channel identity comes from
        # timeframe_embed. It is kept because its 'pe' buffer is part of the
        # state_dict of existing checkpoints (strict load_state_dict).
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=False,
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Return ``(pred_high, pred_low)``.

        Args:
            x: ``[batch, num_channels, FEATURES_PER_CHANNEL]`` features.
            timeframe_ids: 1-D ``[num_channels]`` integer channel ids.

        Returns:
            Two tensors of ``batch`` values each, passed through ``.squeeze()``
            (a 0-d tensor when ``batch == 1``).
        """
        num_channels = x.shape[1]
        channel_outs = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]
        # Sequence-first layout for the encoder (batch_first=False): [channels, batch, hidden]
        stacked = torch.stack(channel_outs, dim=0)
        tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)  # [channels, 1, hidden]
        stacked = stacked + tf_embeds
        seq_len = stacked.size(0)
        # Boolean upper-triangular mask (True = blocked): channel i attends to 0..i only.
        src_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()
        transformer_out = self.transformer(stacked, mask=src_mask)
        # Softmax over the channel axis -> weighted sum pooled into [batch, hidden].
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)
        return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()


# --- Technical Indicator Helpers ---
def _window_mean(candles_list, index, period, key):
    """Mean of ``candle[key]`` over the last ``period`` candles ending at ``index`` (0.0 if none)."""
    start = max(0, index - period + 1)
    values = [candle[key] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def compute_sma(candles_list, index, period=10):
    """Simple moving average of ``close`` over up to ``period`` candles ending at ``index``."""
    return _window_mean(candles_list, index, period, "close")


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of ``volume`` over up to ``period`` candles ending at ``index``."""
    return _window_mean(candles_list, index, period, "volume")


def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` of the last candle with ``timestamp <= target_ts``.

    Scans forward and stops at the first candle newer than ``target_ts``, so it
    assumes ``candles_list`` is sorted by timestamp. If every candle is newer
    than ``target_ts``, index 0 is returned. Raises IndexError on an empty list.
    """
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] > target_ts:
            break
        best_idx = i
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Return the 7 features [open, high, low, close, volume, sma_close, sma_volume] at ``index``."""
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the candles dict from a JSON file; return {} if missing or unreadable."""
    if os.path.exists(filename):
        try:
            with open(filename, "r", encoding="utf-8") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
            print("Error reading cache file:", e)
    return {}


def save_candles_cache(filename, candles_dict):
    """Best-effort write of ``candles_dict`` as JSON; errors are printed, not raised."""
    try:
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(candles_dict, f)
    except (OSError, TypeError, ValueError) as e:
        print("Error saving cache file:", e)


def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries (by mtime) in ``directory`` until at most ``max_files`` remain."""
    files = os.listdir(directory)
    excess = len(files) - max_files
    if excess <= 0:
        return
    full_paths = sorted((os.path.join(directory, f) for f in files), key=os.path.getmtime)
    for path in full_paths[:excess]:
        os.remove(path)


def get_best_models(directory):
    """Return ``(loss, filename)`` for files named like ``best_<loss>_...``; others are skipped."""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
        except (IndexError, ValueError):
            continue
        best_files.append((loss, file))
    return best_files


def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Always save a rolling "last" checkpoint; also keep it in ``best_dir`` if it is top-10 by loss."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    payload = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(payload, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    best_models = get_best_models(best_dir)
    add_to_best = len(best_models) < 10
    if not add_to_best:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        # The loss is encoded in the filename; get_best_models() parses it back.
        best_path = os.path.join(best_dir, f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt")
        torch.save(payload, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR, map_location="cpu"):
    """Load the lowest-loss checkpoint from ``best_dir`` into ``model``.

    ``map_location`` defaults to CPU so checkpoints saved on a GPU also load on
    CPU-only hosts; ``load_state_dict`` then copies weights onto the model's device.

    Returns:
        The checkpoint dict, or None if no checkpoint was found.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
    """
    Generate a chart image that uses actual timestamps on the x-axis and shows a
    cumulative epoch PnL. The chart (with buy/sell markers and dotted lines) is
    embedded in an HTML page that auto-refreshes every 10 seconds.
    """
    if plt is None:
        print("matplotlib is not installed; skipping live_chart.html update.")
        return
    from io import BytesIO
    import base64

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        update_live_chart(ax, candles, trade_history)
        # Compute cumulative epoch PnL.
        epoch_pnl = sum(trade["pnl"] for trade in trade_history)
        ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)  # free the figure even if drawing fails
    image_base64 = base64.b64encode(buf.getvalue()).decode('ascii')
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title>
</head>
<body>
<h1>Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}</h1>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart">
</body>
</html>
"""
    # Write to a temp file and swap it in so the auto-refreshing page never
    # loads a half-written file.
    tmp_path = "live_chart.html.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    os.replace(tmp_path, "live_chart.html")
    print("Updated live_chart.html.")


# --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history):
    """
    Plot the price chart with actual timestamps on the x-axis.
    Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.

    Timestamps go through ``datetime.fromtimestamp``, i.e. they are treated as
    Unix seconds. Trade ``entry_index``/``exit_index`` index into ``candles``.
    """
    ax.clear()
    times = [datetime.fromtimestamp(candle["timestamp"]) for candle in candles]
    close_prices = [candle["close"] for candle in candles]
    ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))

    for n, trade in enumerate(trade_history):
        entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"])
        exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"])
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]
        # Only the first trade contributes legend entries.
        buy_label = "BUY" if n == 0 else "_nolegend_"
        sell_label = "SELL" if n == 0 else "_nolegend_"
        ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label=buy_label)
        ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label=sell_label)
        ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")

    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    ax.get_figure().autofmt_xdate()


# --- Model Inference Helpers ---
def state_to_model_inputs(state, device):
    """Turn one ``get_state`` array into ``(state_tensor [1, C, F], timeframe_ids [C])``."""
    state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    timeframe_ids = torch.arange(state.shape[0], device=device)
    return state_tensor, timeframe_ids


def predict_high_low(model, state, device):
    """No-grad forward pass on one state; return ``(pred_high, pred_low)`` as floats.

    The caller is responsible for putting ``model`` in eval mode.
    """
    state_tensor, timeframe_ids = state_to_model_inputs(state, device)
    with torch.no_grad():
        pred_high, pred_low = model(state_tensor, timeframe_ids)
    return pred_high.item(), pred_low.item()


def decide_action(pred_high, pred_low, current_open, threshold):
    """Map predictions to an action: 2 = BUY, 0 = SELL, 1 = HOLD.

    BUY when the predicted up-move is at least the down-move and exceeds
    ``threshold``; SELL when the down-move is strictly larger and exceeds it.
    """
    up_move = pred_high - current_open
    down_move = current_open - pred_low
    if up_move >= down_move and up_move > threshold:
        return 2
    if down_move > up_move and down_move > threshold:
        return 0
    return 1


# --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args):
    """
    Run a simulation on a fresh sliding window using the model's outputs and a
    decision rule. This simulation updates env.trade_history and is used for
    visualization only. Runs in eval mode without autograd; the model's previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        env.reset()  # picks a new random window
        while True:
            i = env.current_index
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            pred_high, pred_low = predict_high_low(model, state, device)
            action = decide_action(pred_high, pred_low, current_open, args.threshold)
            if env.step(action)[3]:  # done flag
                break
    finally:
        model.train(was_training)


# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
    """Replay a random window of base-timeframe candles and record long trades.

    ``candles_dict`` maps timeframe name -> list of candle dicts with
    ``timestamp``, ``open``, ``high``, ``low``, ``close`` and ``volume`` keys.
    The window size is clamped to the number of available base candles.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict  # full candles dict across timeframes
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        if not self.full_candles:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        if window_size < 1:
            raise ValueError("window_size must be at least 1.")
        # Clamp so reset() never asks random.randint for an empty range.
        self.window_size = min(window_size, len(self.full_candles))
        self.reset()

    def reset(self):
        """Pick a new random window, clear position/trades, and return the first state."""
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        return self.window_size

    def get_order_features(self, index):
        """
        Returns a list of 7 features for the order channel.
        If an order is open, the first element is 1.0 and the second is the normalized
        difference: (current open - entry_price) / current open. Otherwise, returns zeros.
        """
        if self.position is None:
            return [0.0] * FEATURES_PER_CHANNEL
        candle = self.candle_window[index]
        diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
        return [1.0, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """
        Build the ``[len(timeframes) + 1 + NUM_INDICATORS, FEATURES_PER_CHANNEL]``
        float32 state for window position ``index``:
          - for each timeframe: features of the aligned candle;
          - one channel with the current order information;
          - NUM_INDICATORS channels of zeros (placeholders).
        """
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        for tf in self.timeframes:
            if tf == self.base_tf:
                # Index into the full series so the SMA features see real history
                # (a one-candle list would make sma_close == close).
                features = get_features_for_tf(self.full_candles, self.start_index + index)
            else:
                # NOTE(review): this picks the candle that *opened* at or before
                # base_ts; its high/low/close may include data after base_ts
                # (lookahead) if that candle has not closed yet — confirm.
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        state_features.append(self.get_order_features(index))
        state_features.extend([0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS))
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """
        Execute one step in the environment:
          - action: 0 => SELL, 1 => HOLD, 2 => BUY.
          - BUY (when flat) enters at the next candle's open; SELL (when long)
            exits at the next candle's open and records the trade.

        Returns:
            (current_state, reward, next_state, done, next_high, next_low). At
            the end of the window: (state, 0.0, None, True, 0.0, 0.0).
        """
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_candle = base[next_index]
        reward = 0.0
        if self.position is None:
            if action == 2:
                # Execution happens at the next candle, so that is the entry index.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None
        # Build next_state after the action so its order channel reflects the new position.
        next_state = self.get_state(next_index)
        self.current_index = next_index
        done = self.current_index >= len(base) - 1
        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]


# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """Train for epochs ``start_epoch..args.epochs-1``, one random window per epoch.

    Per step the loss is: L1 error on the next candle's high/low, plus
    ``lambda_trade`` * (-signal strength), plus ``penalty_noaction`` * how far
    the signal strength falls below ``threshold``. Signal strength is
    ``max(pred_high - open, open - pred_low)``. After each epoch a checkpoint is
    saved with the mean step loss, and a simulated trade chart is rendered.

    NOTE(review): the -signal term keeps pushing pred_high up / pred_low down,
    and ``threshold`` is in absolute price units — confirm both are intended.
    """
    lambda_trade = args.lambda_trade  # Weight for the surrogate profit loss.
    steps = len(env) - 1  # one step per consecutive candle pair
    if steps < 1 and start_epoch < args.epochs:
        raise ValueError("Training needs a window of at least 2 candles.")
    for epoch in range(start_epoch, args.epochs):
        model.train()
        env.reset()  # new sliding window; position is None, so the order channel is zeros
        loss_accum = 0.0
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            next_candle = env.candle_window[i + 1]
            actual_high = next_candle["high"]
            actual_low = next_candle["low"]

            state_tensor, timeframe_ids = state_to_model_inputs(state, device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)

            # Prediction loss (L1 error).
            L_pred = torch.abs(pred_high - actual_high) + torch.abs(pred_low - actual_low)
            # Surrogate profit: the larger of the potential long / short move.
            signal_strength = torch.maximum(pred_high - current_open, current_open - pred_low)
            L_trade = -signal_strength
            # Additional penalty if no strong signal is produced.
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            loss = L_pred + lambda_trade * L_trade + penalty_term

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_accum += loss.item()

        scheduler.step()
        epoch_loss = loss_accum / steps
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        # Save the same per-step mean that is printed (not the raw sum).
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        simulate_trades(model, env, device, args)
        update_live_html(env.candle_window, env.trade_history, epoch + 1)


# --- Live Plotting Functions (For Live Mode) ---
def live_preview_loop(candles, env):
    """Redraw the env's current window and trades once per second, forever.

    ``candles`` is only a fallback when ``env`` has no ``candle_window``; the env's
    window is re-read every frame because ``env.reset()`` replaces it and trade
    indices refer to the current window.

    NOTE(review): main() runs this in a background thread; most matplotlib GUI
    backends require the main thread — confirm on the target platform.
    """
    if plt is None:
        print("matplotlib is not installed; live preview disabled.")
        return
    plt.ion()
    _fig, ax = plt.subplots(figsize=(12, 6))
    while True:
        window = getattr(env, "candle_window", candles)
        update_live_chart(ax, window, list(env.trade_history))
        plt.draw()
        plt.pause(1)


# --- Argument Parsing ---
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--threshold', type=float, default=0.005,
                        help="Minimum predicted move to trigger trade.")
    parser.add_argument('--lambda_trade', type=float, default=1.0,
                        help="Weight for trade surrogate loss.")
    parser.add_argument('--penalty_noaction', type=float, default=10.0,
                        help="Penalty for not taking an action (if predicted move is below threshold).")
    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
    return parser.parse_args()


def random_action():
    return random.randint(0, 2)


def _missing_timeframes(candles_dict, timeframes):
    """Return the timeframes that have no candles in ``candles_dict``."""
    return [tf for tf in timeframes if not candles_dict.get(tf)]


# --- Main Function ---
async def main():
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    base_tf = "1m"
    hidden_dim = 128
    # Must match the number of rows BacktestEnvironment.get_state() produces:
    # one per timeframe + the order channel + the indicator channels.
    total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS
    model = TradingModel(total_channels, len(timeframes), hidden_dim=hidden_dim).to(device)

    if args.mode == 'train':
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        missing = _missing_timeframes(candles_dict, timeframes)
        if missing:
            print(f"Cached candles are missing timeframes: {missing}")
            return
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)

        start_epoch = 0
        checkpoint = None
        if not args.start_fresh:
            checkpoint = load_best_checkpoint(model)
            if checkpoint is not None:
                start_epoch = checkpoint.get("epoch", 0) + 1
                print(f"Resuming training from epoch {start_epoch}.")
            else:
                print("No checkpoint found. Starting training from scratch.")
        else:
            print("Starting training from scratch as requested.")

        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        # Guard against T_max == 0 when resuming at/after the final epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, args.epochs - start_epoch))

        if checkpoint is not None:
            optim_state = checkpoint.get("optimizer_state_dict", None)
            if optim_state is not None and "param_groups" in optim_state:
                optimizer.load_state_dict(optim_state)
                print("Loaded optimizer state from checkpoint.")
            else:
                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")

        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)

    elif args.mode == 'live':
        if load_best_checkpoint(model) is None:
            print("No checkpoint found; using an untrained model.")
        model.eval()  # disable dropout for inference
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
        missing = _missing_timeframes(candles_dict, timeframes)
        if missing:
            print(f"Cached candles are missing timeframes: {missing}")
            return
        env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes, window_size=100)
        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
        print("Starting live trading loop. (Using model-based decision rule.)")
        while True:
            idx = env.current_index
            state = env.get_state(idx)
            current_open = env.candle_window[idx]["open"]
            pred_high, pred_low = predict_high_low(model, state, device)
            action = decide_action(pred_high, pred_low, current_open, args.threshold)
            if env.step(action)[3]:  # done flag
                print("Reached end of simulation window; resetting environment.")
                env.reset()
            await asyncio.sleep(1)

    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Your inference logic goes here.
    else:
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())