#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import os
import time
import json
import argparse
import threading
import random
import math
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from dotenv import load_dotenv

load_dotenv()

# Define global constants FIRST.
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"
FEATURES_PER_CHANNEL = 7  # open, high, low, close, volume, sma_close, sma_volume
NUM_INDICATORS = 20       # Placeholder indicator channels; the original value is not shown in this excerpt, so 20 is an assumption

# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    ts = float(ts)
    if ts > 1e10:  # Values this large are milliseconds, not seconds
        ts /= 1000.0
    return datetime.fromtimestamp(ts)

# -------------------------------
# Historical Data Fetching Functions (Using CCXT)
# -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print("Error fetching historical data:", e)
            break
        if not batch:
            break
        for c in batch:
            candles.append({
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            })
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles

async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
    # Note: this updates the cache in memory only; the caller is responsible
    # for persisting it with save_candles_cache().
    cached_candles = load_candles_cache(cache_file)
    if cached_candles and timeframe in cached_candles:
        last_ts = cached_candles[timeframe][-1]['timestamp']
        if last_ts < end_time:
            print("Fetching new candles to update cache...")
            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
            cached_candles[timeframe].extend(new_candles)
        else:
            print("Cache covers the requested period.")
        return cached_candles[timeframe]
    return await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)

# -------------------------------
# Cache and Training Cache Helpers
# -------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return {}  # Return empty dict if no cache

def save_candles_cache(filename, candles_dict):
    try:
        with open(filename, "w") as f:
            json.dump(candles_dict, f)
    except Exception as e:
        print("Error saving cache file:", e)

def load_training_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                cache = json.load(f)
            print(f"Loaded training cache from {filename}.")
            return cache
        except Exception as e:
            print("Error loading training cache:", e)
    return {"total_pnl": 0.0}  # Initialize if not found

def save_training_cache(filename, cache):
    try:
        with open(filename, "w") as f:
            json.dump(cache, f)
    except Exception as e:
        print("Error saving training cache:", e)
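
# -------------------------------
# Illustrative sketch (not called anywhere in this script): wiring the
# fetch/cache helpers together. Assumes ccxt's async API is installed and a
# Binance spot market; the symbol, timeframe, and lookback are placeholders.
# -------------------------------
async def _example_fetch_and_cache():
    import ccxt.async_support as ccxt  # Assumed dependency; the real script builds its exchange object elsewhere
    exchange = ccxt.binance()
    try:
        end_time = int(time.time() * 1000)       # Now, in ms (ccxt uses millisecond timestamps)
        since = end_time - 24 * 60 * 60 * 1000   # Last 24 hours
        candles = await get_cached_or_fetch_data(exchange, "BTC/USDT", "1m", since, end_time)
        # Persist explicitly: get_cached_or_fetch_data does not write the cache itself.
        save_candles_cache(CACHE_FILE, {"1m": candles})
    finally:
        await exchange.close()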

# -------------------------------
# Checkpoint Functions
# -------------------------------
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

def maintain_checkpoint_directory(directory, max_files=10):
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[:len(files) - max_files]:
            os.remove(f)

def get_best_models(directory):
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            # Loss is the second underscore-separated field: best_{loss}_epoch_...
            # (model_last_* files fail the float parse and are skipped.)
            loss = float(parts[1])
            best_files.append((loss, file))
        except Exception:
            continue
    return best_files

def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:  # Save only if better than the worst of the current top 10
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))  # Evict the worst
    if add_to_best:
        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"  # Include loss in the name
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "loss": loss,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")

def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])  # Load the best (lowest-loss) checkpoint
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    # Handle a potential timeframe-embedding size mismatch
    # (e.g., the model now uses more timeframes than at save time).
    old_state = checkpoint["model_state_dict"]
    new_state = model.state_dict()
    if "timeframe_embed.weight" in old_state:
        old_embed = old_state["timeframe_embed.weight"]
        new_embed = new_state["timeframe_embed.weight"]
        if old_embed.shape[0] < new_embed.shape[0]:
            # Copy the old rows into the (larger) new embedding table.
            new_embed[:old_embed.shape[0]] = old_embed
            old_state["timeframe_embed.weight"] = new_embed
    model.load_state_dict(old_state, strict=False)  # strict=False tolerates remaining size differences
    return checkpoint
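
# Illustrative sketch (not called anywhere): a save/load round trip for the
# checkpoint helpers above. The channel counts are assumptions; see
# TradingModel below for the real constructor signature.
def _example_checkpoint_roundtrip():
    num_channels = 5 + 1 + NUM_INDICATORS  # 5 assumed timeframes + 1 order channel + indicator placeholders
    # The training loop indexes timeframe_embed with one id per channel,
    # so the embedding table must cover every channel.
    model = TradingModel(num_channels=num_channels, num_timeframes=num_channels)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    save_checkpoint(model, optimizer, epoch=0, loss=1.2345)
    checkpoint = load_best_checkpoint(model)  # Returns None if models/best is empty
    if checkpoint is not None:
        print(f"Resuming from epoch {checkpoint['epoch']} (loss {checkpoint['loss']:.4f})")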

# -------------------------------
# Positional Encoding and Transformer-Based Model
# -------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TradingModel(nn.Module):
    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1)
            ) for _ in range(num_channels)
        ])
        # Embedding for each timeframe
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
        # Note: pos_encoder is constructed but never applied in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)
        # Increased number of layers, heads, and feed-forward width for a larger model
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=6)  # More layers
        # Attention pooling to aggregate channel outputs
        self.attn_pool = nn.Linear(hidden_dim, 1)
        # Separate prediction heads for high and low
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, timeframe_ids):
        # x shape: (batch, num_channels, features_per_channel)
        batch_size, num_channels, _ = x.shape
        channel_outs = []
        # Process each channel through its own branch
        for i in range(num_channels):
            channel_out = self.channel_branches[i](x[:, i, :])
            channel_outs.append(channel_out)
        # Stack channel outputs
        stacked = torch.stack(channel_outs, dim=1)  # (batch, num_channels, hidden_dim)
        # Add timeframe embeddings; len(timeframe_ids) must equal num_channels for the broadcast below
        tf_embeds = self.timeframe_embed(timeframe_ids)  # (num_channels, hidden_dim)
        stacked = stacked + tf_embeds.unsqueeze(0)  # Broadcast over the batch
        # Transformer over the channel axis
        transformer_out = self.transformer(stacked)  # (batch, num_channels, hidden_dim)
        # Attention pooling
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)  # (batch, num_channels, 1)
        aggregated = (transformer_out * attn_weights).sum(dim=1)  # (batch, hidden_dim)
        # Predict high and low
        return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
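
# Illustrative sketch (not called anywhere): a single forward pass with dummy
# inputs, documenting the expected tensor shapes. The channel count is an
# assumption; the real count depends on the configured timeframes.
def _example_forward_pass():
    num_channels = 5 + 1 + NUM_INDICATORS  # 5 assumed timeframes + order channel + indicator placeholders
    model = TradingModel(num_channels=num_channels, num_timeframes=num_channels)
    x = torch.randn(4, num_channels, FEATURES_PER_CHANNEL)  # (batch, channels, features)
    timeframe_ids = torch.arange(num_channels)              # One id per channel
    pred_high, pred_low = model(x, timeframe_ids)           # Each: (batch,) after squeeze()
    print(pred_high.shape, pred_low.shape)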

# -------------------------------
# Technical Indicator Helpers
# -------------------------------
def compute_sma(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0

def compute_sma_volume(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["volume"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0

def get_aligned_candle_with_index(candles_list, target_ts):
    """Find the last candle whose timestamp is <= target_ts; return (index, candle)."""
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break  # Candles are time-ordered, so stop once past the target
    return best_idx, candles_list[best_idx]

def get_features_for_tf(candles_list, index, period=10):
    """Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
    candle = candles_list[index]
    f_open = candle["open"]
    f_high = candle["high"]
    f_low = candle["low"]
    f_close = candle["close"]
    f_volume = candle["volume"]
    sma_close = compute_sma(candles_list, index, period)
    sma_volume = compute_sma_volume(candles_list, index, period)
    return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]

# -------------------------------
# Backtest Environment Class
# -------------------------------
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]  # All candles for the base timeframe
        # Window size: default to 100, capped at the available history
        if window_size is None:
            window_size = min(100, len(self.full_candles))
        self.window_size = window_size
        self.reset()  # Initialize

    def reset(self):
        # Randomly select a starting point for the window
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None  # None, or {"entry_price": ..., "entry_index": ...}
        return self.get_state(self.current_index)  # Return initial state

    def __len__(self):
        return self.window_size  # Length of the environment is the window size

    def get_order_features(self, index):
        """Get features related to the current order (if any)."""
        candle = self.candle_window[index]
        if self.position is None:
            return [0.0] * FEATURES_PER_CHANNEL  # No position: 7 zeros
        # In a position: [1.0, relative_price_diff, 0, 0, 0, 0, 0]
        flag = 1.0
        diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
        return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """Construct the multi-channel state for the given window index."""
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        # Features for each timeframe
        for tf in self.timeframes:
            if tf == self.base_tf:
                # Base timeframe: use the window candle directly
                candle = self.candle_window[index]
                features = get_features_for_tf([candle], 0)  # Single-candle list, so the SMAs equal the raw values
            else:
                # Other timeframes: align on the base timestamp
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        # Order channel
        state_features.append(self.get_order_features(index))
        # Placeholder channels for additional indicators
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Advance one candle given an action (0=SELL, 2=BUY; anything else holds)."""
        base = self.candle_window  # Shorter name
        if self.current_index >= len(base) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0  # No reward at the very end, and done
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base[next_index]  # Next candle for open, high, low
        reward = 0.0
        # Handle actions (simplified long-only execution on the next candle)
        if self.position is None:
            if action == 2:  # BUY
                self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
        else:
            if action == 0:  # SELL
                exit_price = next_candle["close"]
                reward = exit_price - self.position["entry_price"]  # PnL is the reward
                self.trade_history.append({
                    "entry_index": self.position["entry_index"],
                    "entry_price": self.position["entry_price"],
                    "exit_index": next_index,
                    "exit_price": exit_price,
                    "pnl": reward
                })
                self.position = None  # Exit position
        self.current_index = next_index
        done = (self.current_index >= len(base) - 1)  # Done at the end of the window
        actual_high = next_candle["high"]
        actual_low = next_candle["low"]
        return current_state, reward, next_state, done, actual_high, actual_low
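
# Illustrative sketch (not called anywhere): driving the environment with
# random actions. The timeframe list and the pre-populated candles_dict are
# assumptions about how the environment is constructed elsewhere.
def _example_backtest_episode(candles_dict):
    timeframes = ["1m", "5m", "15m", "1h", "1d"]  # Assumed; the base timeframe comes first
    env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
    env.reset()
    done = False
    while not done:
        action = random.choice([0, 1, 2])  # 0=SELL, 2=BUY, anything else holds
        state, reward, next_state, done, actual_high, actual_low = env.step(action)
    total = sum(t["pnl"] for t in env.trade_history)
    print(f"Episode finished with {len(env.trade_history)} trades, PnL {total:.2f}")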

# -------------------------------
# Enhanced Training Loop
# -------------------------------
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    lambda_trade = args.lambda_trade  # Weight for the trade-surrogate loss
    # Load training cache (for total PnL tracking across runs)
    training_cache = load_training_cache(TRAINING_CACHE_FILE)
    total_pnl = training_cache.get("total_pnl", 0.0)
    for epoch in range(start_epoch, args.epochs):
        env.reset()  # New random window each epoch
        loss_accum = 0.0
        steps = len(env) - 1  # Number of steps in the episode
        for i in range(steps):  # Iterate through the episode
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]       # Current candle's open
            actual_high = env.candle_window[i + 1]["high"]    # Next candle's high
            actual_low = env.candle_window[i + 1]["low"]      # Next candle's low
            # Forward pass
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)  # Add batch dimension
            timeframe_ids = torch.arange(state.shape[0]).to(device)          # One id per channel
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            # Prediction loss (L_pred): L1 error on the next high/low
            L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
                     torch.abs(pred_low - torch.tensor(actual_low, device=device))
            # Trade surrogate loss (L_trade): reward the better of the two potential profits
            profit_buy = pred_high - current_open   # Potential profit if buying
            profit_sell = current_open - pred_low   # Potential profit if selling
            L_trade = -torch.max(profit_buy, profit_sell)  # Minimize negative profit
            # No-action penalty: encourage signals stronger than the threshold
            current_open_tensor = torch.tensor(current_open, device=device)
            signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            # Total loss
            loss = L_pred + lambda_trade * L_trade + penalty_term
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            scheduler.step()
            loss_accum += loss.item()
        epoch_loss = loss_accum / steps  # Average loss per step
        # Simulate trades first, so trade_history reflects this epoch's policy
        # before the PnL and the no-trade penalty are computed.
        # (simulate_trades is defined elsewhere in the script, not shown in this excerpt.)
        simulate_trades(model, env, device, args)
        if len(env.trade_history) == 0:
            epoch_loss *= 3  # Penalize (for reporting) epochs that produce no trades
        epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)  # PnL for the epoch
        total_pnl += epoch_pnl  # Update total PnL
        print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
        save_checkpoint(model, optimizer, epoch, loss_accum)  # Save with accumulated loss
        update_live_html(env.candle_window, env.trade_history, epoch + 1, epoch_loss, total_pnl)  # Update HTML visualization
        # Persist total PnL each epoch so progress survives interruptions
        training_cache["total_pnl"] = total_pnl
        save_training_cache(TRAINING_CACHE_FILE, training_cache)
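
# Illustrative sketch (not called anywhere): the optimizer/scheduler/args
# wiring that train_on_historical_data() expects. The hyperparameter values
# and the OneCycleLR choice are assumptions, not the original configuration.
def _example_training_setup(env, model, device):
    args = argparse.Namespace(
        epochs=100,            # Assumed
        lambda_trade=1.0,      # Weight on the trade-surrogate loss
        penalty_noaction=0.1,  # Weight on the weak-signal penalty
        threshold=0.0,         # Minimum signal strength before the penalty applies
    )
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    # scheduler.step() is called once per training step, so size the schedule in steps.
    total_steps = args.epochs * (len(env) - 1)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=total_steps)
    train_on_historical_data(env, model, device, args, start_epoch=0,
                             optimizer=optimizer, scheduler=scheduler)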

# -------------------------------
# Live Plotting (for Live Mode)
# -------------------------------
def live_preview_loop(candles, env):
    plt.ion()  # Interactive mode
    fig, ax = plt.subplots(figsize=(12, 6))
    while True:
        # update_live_chart is defined elsewhere in the script, not shown in this excerpt.
        update_live_chart(ax, candles, env.trade_history)
        plt.draw()
        plt.pause(1)  # Redraw every second

# -------------------------------
# Live HTML Chart Update (with Volume and Loss)
# -------------------------------
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
    from io import BytesIO
    import base64
    # Create a new figure and axes for each update
    fig, ax = plt.subplots(figsize=(12, 6))
    # Draw the chart
    update_live_chart(ax, candles, trade_history)
    epoch_pnl = sum(trade["pnl"] for trade in trade_history)  # PnL for this window
    ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
    # Save the plot to a BytesIO buffer and embed it as base64
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to free memory
    buf.seek(0)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    # Generate HTML content
    html_content = f"""