#!/usr/bin/env python3
"""Transformer-based high/low predictor trained on cached multi-timeframe candles.

Modes (selected with ``--mode``):

* ``train``     - replay random windows of cached base-timeframe candles with a
  forced-action trading policy and train the model to predict the next
  candle's high and low.
* ``live``      - replay cached candles with the same forced-action policy while
  a matplotlib window previews the trades.
* ``inference`` - load the best checkpoint (no inference policy is implemented
  yet).
"""
import sys
import asyncio

if sys.platform == 'win32':
    # Selector-based loop on Windows; presumably required by the async exchange
    # client (ccxt / aiohttp) - TODO confirm it is still needed.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import argparse
import json
import math
import os
import random
import threading
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import ccxt.async_support as ccxt
from dotenv import load_dotenv

load_dotenv()

# --- Directories ---
LAST_DIR = os.path.join("models", "last")  # rolling set of the most recent checkpoints
BEST_DIR = os.path.join("models", "best")  # lowest-loss checkpoints
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"

# --- Constants ---
NUM_TIMEFRAMES = 5         # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20        # indicator channels (currently zero-filled placeholders)
FEATURES_PER_CHANNEL = 7   # open, high, low, close, volume, SMA(close), SMA(volume)
MAX_CHECKPOINTS = 10       # files kept in each checkpoint directory


# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings, then apply dropout.

    The encoding buffer has shape ``[max_len, 1, d_model]`` and is sliced by
    ``x.size(0)``, so ``x`` is expected sequence-first
    (``[seq_len, batch, d_model]``). ``d_model`` must be even so the sine and
    cosine halves have matching sizes.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # A buffer, not a parameter: saved in state_dict and moved by .to(),
        # but never updated by the optimizer.
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Predict the next candle's high and low from per-channel feature rows.

    Each channel (one row of ``FEATURES_PER_CHANNEL`` features) is projected by
    its own MLP branch, tagged with a learned per-channel embedding, mixed by a
    causally-masked transformer over the channel axis, attention-pooled, and
    fed to separate high/low regression heads.

    ``num_timeframes`` is accepted for API compatibility but is not used.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        # One independent projection branch per input channel.
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )
            for _ in range(num_channels)
        ])
        # Learned embedding indexed by channel id 0..num_channels-1 (covers the
        # indicator channels too, despite the name).
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE(review): registered (its buffer is part of state_dict, so keep it
        # for checkpoint compatibility) but never applied in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=False,
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Return ``(pred_high, pred_low)``, each of shape ``[batch]``.

        Args:
            x: ``[batch, num_channels, FEATURES_PER_CHANNEL]`` features.
            timeframe_ids: ``[num_channels]`` channel ids for the embedding.
        """
        num_channels = x.shape[1]
        # Stack along dim 0 to get the sequence-first layout the encoder
        # expects (batch_first=False): [channels, batch, hidden].
        stacked = torch.stack(
            [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)],
            dim=0,
        )
        # [channels, 1, hidden] broadcasts over the batch axis.
        stacked = stacked + self.timeframe_embed(timeframe_ids).unsqueeze(1)
        seq_len = stacked.size(0)
        # True = masked: channel i may only attend to channels 0..i.
        src_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        transformer_out = self.transformer(stacked, mask=src_mask)
        # Softmax over the channel axis -> weighted sum pools channels away.
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)
        # squeeze(-1) keeps the batch axis even when batch size is 1.
        return self.high_pred(aggregated).squeeze(-1), self.low_pred(aggregated).squeeze(-1)


# --- Technical Indicator Helpers ---
def compute_sma(candles_list, index, period=10):
    """Mean ``close`` of the up-to-``period`` candles ending at ``index`` (0.0 if none)."""
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def compute_sma_volume(candles_list, index, period=10):
    """Mean ``volume`` of the up-to-``period`` candles ending at ``index`` (0.0 if none)."""
    start = max(0, index - period + 1)
    values = [candle["volume"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` of the last candle with ``timestamp <= target_ts``.

    Assumes ``candles_list`` is sorted by ascending timestamp (the scan stops
    at the first later candle). Falls back to index 0 when no candle
    qualifies; raises IndexError on an empty list.
    """
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] > target_ts:
            break
        best_idx = i
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Return ``[open, high, low, close, volume, sma_close, sma_volume]`` for one candle.

    The SMAs look back over ``candles_list`` ending at ``index``, so pass the
    full history (not a single candle) for them to be meaningful.
    """
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the candle cache JSON; return ``{}`` if it is missing or unreadable."""
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print("Error reading cache file:", e)
        return {}
    print(f"Loaded cached data from {filename}.")
    return data


def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` to ``filename`` as JSON (best effort; errors are printed).

    The data is written to a temporary file first and atomically swapped in,
    so an interrupted write cannot leave a truncated cache behind.
    """
    tmp_path = filename + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(candles_dict, f)
        os.replace(tmp_path, filename)
    except (OSError, TypeError, ValueError) as e:
        print("Error saving cache file:", e)


def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest (by mtime) regular files in ``directory`` beyond ``max_files``."""
    paths = [os.path.join(directory, name) for name in os.listdir(directory)]
    files = [path for path in paths if os.path.isfile(path)]
    excess = len(files) - max_files
    if excess > 0:
        files.sort(key=os.path.getmtime)
        for path in files[:excess]:
            os.remove(path)


def get_best_models(directory):
    """Return ``(loss, filename)`` pairs for ``best_<loss>_...`` files in ``directory``.

    The loss is parsed from the filename as written by ``save_checkpoint``;
    files that do not follow that pattern are ignored.
    """
    best_files = []
    for name in os.listdir(directory):
        if not name.startswith("best_"):
            continue
        parts = name.split("_")
        try:
            loss = float(parts[1])
        except (IndexError, ValueError):
            continue
        best_files.append((loss, name))
    return best_files


def _prune_best_models(best_dir, keep):
    """Delete best-model files beyond the ``keep`` lowest recorded losses."""
    for _, name in sorted(get_best_models(best_dir))[keep:]:
        os.remove(os.path.join(best_dir, name))


def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a checkpoint to ``last_dir`` and, if it ranks, to ``best_dir``.

    ``last_dir`` keeps the ``MAX_CHECKPOINTS`` most recent files. ``best_dir``
    keeps the ``MAX_CHECKPOINTS`` lowest-loss files: the new checkpoint is
    added when there is room or when ``loss`` beats the current worst, and the
    directory is then pruned by loss (never by age, so good models survive).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt"))
    maintain_checkpoint_directory(last_dir, max_files=MAX_CHECKPOINTS)

    best_models = get_best_models(best_dir)
    worst_loss = max((recorded for recorded, _ in best_models), default=None)
    if len(best_models) < MAX_CHECKPOINTS or loss < worst_loss:
        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        torch.save(checkpoint, os.path.join(best_dir, best_filename))
        _prune_best_models(best_dir, keep=MAX_CHECKPOINTS)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the lowest-loss checkpoint's weights into ``model``.

    Returns the checkpoint dict (``epoch``, ``loss``, ``model_state_dict``,
    ``optimizer_state_dict``), or ``None`` if ``best_dir`` has no checkpoints.
    Tensors are mapped to CPU on load so GPU-saved checkpoints also open on
    CPU-only hosts; ``load_state_dict`` copies them onto the model's device.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
    """Render the trade chart into ``live_chart.html``.

    The chart (close price, BUY/SELL markers and dotted entry/exit lines) is
    rendered to PNG, base64-encoded and embedded inline in a small HTML page
    whose meta refresh reloads it every 10 seconds.
    """
    from io import BytesIO
    import base64

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        update_live_chart(ax, candles, trade_history)
        ax.set_title(f"Live Trading Chart - Epoch {epoch}")
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure; one leaks per epoch otherwise.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    html_content = f"""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <meta http-equiv="refresh" content="10">
  <title>Live Trading Chart - Epoch {epoch}</title>
</head>
<body>
  <h1>Live Trading Chart - Epoch {epoch}</h1>
  <img src="data:image/png;base64,{image_base64}" alt="Live Chart">
</body>
</html>
"""
    with open("live_chart.html", "w", encoding="utf-8") as f:
        f.write(html_content)
    print("Updated live_chart.html.")


# --- Chart Drawing Helpers (used by both live preview and HTML update) ---
def update_live_chart(ax, candles, trade_history):
    """Clear ``ax`` and draw close prices plus each trade's entry/exit.

    Entries are green ``^`` markers, exits red ``v`` markers, joined by a
    dotted blue line. Only the first BUY/SELL markers carry legend labels.
    """
    ax.clear()
    close_prices = [candle["close"] for candle in candles]
    ax.plot(list(range(len(close_prices))), close_prices, label="Close Price", color="black", linewidth=1)
    for n, trade in enumerate(trade_history):
        in_idx, in_price = trade["entry_index"], trade["entry_price"]
        out_idx, out_price = trade["exit_index"], trade["exit_price"]
        # "_nolegend_" keeps repeated markers out of the legend.
        ax.plot(in_idx, in_price, marker="^", color="green", markersize=10,
                label="BUY" if n == 0 else "_nolegend_")
        ax.plot(out_idx, out_price, marker="v", color="red", markersize=10,
                label="SELL" if n == 0 else "_nolegend_")
        ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue")
    ax.set_xlabel("Candle Index")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)


# --- Forced Action & Optimal Hint Helpers ---
def get_forced_action(env):
    """Return the scripted action for ``env``'s current step.

    - First step: BUY (2).
    - From the penultimate index on: SELL (0) if a position is open, else HOLD (1).
    - Otherwise: HOLD (1).

    This guarantees at least one round-trip trade per window. The environment
    separately penalizes actions that differ from its optimal hint.
    """
    if env.current_index == 0:
        return 2  # BUY
    if env.current_index >= len(env) - 2 and env.position is not None:
        return 0  # SELL
    return 1  # HOLD


# --- Backtest Environment with Sliding Window and Hints ---
class BacktestEnvironment:
    """Replay a random window of base-timeframe candles as a long-only market.

    Actions: 0 = SELL (close the open position), 1 = HOLD, 2 = BUY (open a
    position). Orders fill at the next candle's open. A state has one feature
    row per timeframe followed by ``NUM_INDICATORS`` all-zero rows.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None, hint_threshold=0.005):
        self.candles_dict = candles_dict  # timeframe -> list of candle dicts
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        if not self.full_candles:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        # Clamp to the available history so reset() always has a valid start
        # index (random.randint raises on an empty range).
        self.window_size = max(1, min(window_size, len(self.full_candles)))
        self.hint_penalty = 0.001  # multiplied by the next candle's open price
        self.hint_threshold = hint_threshold  # relative move for a BUY/SELL hint
        self.reset()

    def reset(self):
        """Pick a new random window, clear trades/position, return the first state."""
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        return self.window_size

    def get_state(self, index):
        """Build the ``[channels, FEATURES_PER_CHANNEL]`` float32 state at window ``index``.

        The base timeframe reads the full history (so its SMA columns see the
        candles before ``index``); other timeframes use the last candle whose
        timestamp is at or before the base candle's.
        """
        base_pos = self.start_index + index  # same candle as candle_window[index]
        base_ts = self.full_candles[base_pos]["timestamp"]
        state_features = []
        for tf in self.timeframes:
            if tf == self.base_tf:
                features = get_features_for_tf(self.full_candles, base_pos)
            else:
                # NOTE(review): if timestamps are candle *open* times, this
                # higher-timeframe candle is still forming at base_ts and its
                # high/low/close include later prices (look-ahead) - confirm.
                tf_candles = self.candles_dict[tf]
                aligned_idx, _ = get_aligned_candle_with_index(tf_candles, base_ts)
                features = get_features_for_tf(tf_candles, aligned_idx)
            state_features.append(features)
        state_features.extend([0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS))
        return np.array(state_features, dtype=np.float32)

    def compute_optimal_hint(self, horizon=10, threshold=0.005):
        """Look ahead up to ``horizon`` candles and suggest an action.

        Returns 2 (BUY) if some future high exceeds the current open by at
        least ``threshold`` (relative), else 0 (SELL) if some future low is
        that far below it, else 1 (HOLD). Also HOLD at the window end or for a
        non-positive open price.
        """
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            return 1  # HOLD
        open_price = base[self.current_index]["open"]
        future_slice = base[self.current_index + 1:self.current_index + 1 + horizon]
        if not future_slice or open_price <= 0:
            return 1  # HOLD
        max_future = max(candle["high"] for candle in future_slice)
        min_future = min(candle["low"] for candle in future_slice)
        if (max_future - open_price) / open_price >= threshold:
            return 2  # BUY
        if (open_price - min_future) / open_price >= threshold:
            return 0  # SELL
        return 1  # HOLD

    def step(self, action):
        """Apply ``action`` at the current candle and advance by one.

        Returns ``(state, reward, next_state, done, actual_high, actual_low)``
        where the actuals are the next candle's high/low. ``reward`` is the
        realized PnL of a closing SELL, minus ``hint_penalty * next_open`` when
        the action differs from ``compute_optimal_hint``. At the end of the
        window it returns ``(state, 0.0, None, True, 0.0, 0.0)``.
        """
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            return self.get_state(self.current_index), 0.0, None, True, 0.0, 0.0

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base[next_index]
        reward = 0.0

        if self.position is None:
            if action == 2:
                # BUY fills at the next candle's open, so record that candle's
                # index (matching how exits are recorded).
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:
            # SELL: exit at the next candle's open.
            exit_price = next_candle["open"]
            entry = self.position
            reward = exit_price - entry["entry_price"]
            self.trade_history.append({
                "entry_index": entry["entry_index"],
                "entry_price": entry["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = self.current_index >= len(base) - 1

        # Penalize deviating from the hint computed at the new index, whose
        # open is the fill price of this step's order.
        optimal_hint = self.compute_optimal_hint(horizon=10, threshold=self.hint_threshold)
        if action != optimal_hint:
            reward -= self.hint_penalty * next_candle["open"]

        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]


# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """Train for epochs ``start_epoch .. args.epochs - 1``, one random window each.

    Each step drives ``env`` with the forced-action policy and regresses the
    model's high/low prediction for the current state onto the next candle's
    actual high/low (L1 loss, scaled by 2). After every epoch the scheduler is
    stepped, a checkpoint is saved with the mean per-step loss, and
    ``live_chart.html`` is refreshed.
    """
    for epoch in range(start_epoch, args.epochs):
        state = env.reset()
        total_loss = 0.0
        num_steps = 0
        model.train()
        while True:
            action = get_forced_action(env)
            _, _, next_state, done, actual_high, actual_low = env.step(action)
            if next_state is None:
                # Terminal no-op (one-candle window): there is no next candle,
                # and the 0.0 placeholders are not real targets.
                break

            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            timeframe_ids = torch.arange(state.shape[0], device=device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)

            target_high = torch.tensor([actual_high], dtype=torch.float32, device=device)
            target_low = torch.tensor([actual_low], dtype=torch.float32, device=device)
            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            num_steps += 1
            if done:
                break
            state = next_state

        scheduler.step()
        # Average over the steps actually taken (window_size - 1 per episode).
        epoch_loss = total_loss / max(num_steps, 1)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        update_live_html(env.candle_window, env.trade_history, epoch + 1)


# --- Live Plotting Functions (For live mode) ---
def live_preview_loop(candles, env):
    """Redraw the trade chart about once per second, forever.

    Reads ``env.candle_window`` and ``env.trade_history`` on every frame so the
    chart follows ``env.reset()``; ``candles`` is only used when ``env`` is None.

    NOTE(review): main() runs this in a background thread; matplotlib GUI
    backends are generally not thread-safe - consider drawing from the main
    thread.
    """
    plt.ion()
    _fig, ax = plt.subplots(figsize=(12, 6))
    while True:
        window = env.candle_window if env is not None else candles
        history = env.trade_history if env is not None else []
        update_live_chart(ax, window, history)
        plt.draw()
        plt.pause(1)


# --- Argument Parsing ---
def parse_args():
    """Parse command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--threshold', type=float, default=0.005,
                        help='Relative price move that makes the optimal hint BUY/SELL.')
    # If set, training starts from scratch (ignoring saved checkpoints)
    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
    return parser.parse_args()


def random_action():
    """Return a uniformly random action: 0 (SELL), 1 (HOLD) or 2 (BUY)."""
    return random.randint(0, 2)


# --- Main Function ---
async def main():
    """Entry point: dispatch on ``--mode``."""
    args = parse_args()
    # Use GPU if available; else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    hidden_dim = 128
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(total_channels, NUM_TIMEFRAMES, hidden_dim=hidden_dim).to(device)

    if args.mode == 'train':
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        base_tf = "1m"
        # Create the environment with a sliding window (simulate streaming data)
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100,
                                  hint_threshold=args.threshold)

        start_epoch = 0
        checkpoint = None
        if not args.start_fresh:
            checkpoint = load_best_checkpoint(model)
            if checkpoint is not None:
                start_epoch = checkpoint.get("epoch", 0) + 1
                print(f"Resuming training from epoch {start_epoch}.")
            else:
                print("No checkpoint found. Starting training from scratch.")
        else:
            print("Starting training from scratch as requested.")

        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        if checkpoint is not None:
            optim_state = checkpoint.get("optimizer_state_dict", None)
            if optim_state is not None and "param_groups" in optim_state:
                optimizer.load_state_dict(optim_state)
                # The saved groups carry the previous run's decayed lr and
                # initial_lr; reset them so the new schedule starts from --lr.
                for group in optimizer.param_groups:
                    group["lr"] = args.lr
                    group.pop("initial_lr", None)
                print("Loaded optimizer state from checkpoint.")
            else:
                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
        # Built after the optimizer state is restored so the cosine schedule
        # anneals from --lr over the remaining epochs (at least 1 to avoid a
        # degenerate period).
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, args.epochs - start_epoch)
        )

        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)

    elif args.mode == 'live':
        load_best_checkpoint(model)
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100,
                                  hint_threshold=args.threshold)
        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
        print("Starting live trading loop. (Using forced-action policy for simulation.)")
        while True:
            action = get_forced_action(env)
            _, _, _, done, _, _ = env.step(action)
            if done:
                print("Reached end of simulation window, resetting environment.")
                env.reset()
            await asyncio.sleep(1)

    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Apply a similar (or learned) policy as needed.

    else:
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())