#!/usr/bin/env python3
import sys
import asyncio

if sys.platform == 'win32':
    # NOTE(review): switches to the selector-based event loop on Windows,
    # presumably for libraries that do not support the default Proactor loop.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.dates as mdates
from dotenv import load_dotenv

load_dotenv()  # load environment variables from a .env file (python-dotenv)


# --- Fetch Real 1s Data (if main_tf=="1s") ---
def fetch_real_1s_data(url="https://api.example.com/1s-data", timeout=10):
    """Fetch real 1-second candle data from an HTTP API.

    Replace the default URL (and add any parameters/auth) with those required
    by your data provider. Expected data format: a JSON list of dictionaries
    with keys "timestamp", "open", "high", "low", "close", "volume".

    Args:
        url: Endpoint to query. The default is a placeholder.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON payload on success, or ``[]`` on any network, HTTP
        status, or JSON-decoding failure (the fetch is best-effort; the error
        is printed, not raised).
    """
    # Imported locally so the dependency is only needed when 1s data is requested.
    import requests

    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers malformed JSON bodies.
        print("Failed to fetch real 1s data:", e)
        return []
    print("Fetched real 1s data successfully.")
    return data


# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    """Convert a Unix timestamp (seconds or milliseconds) to a naive local datetime.

    Values above ``1e10`` are treated as milliseconds and divided by 1000;
    1e10 seconds would be around the year 2286, so real second-based
    timestamps never exceed it.
    """
    ts = float(ts)
    if ts > 1e10:  # Likely in milliseconds
        ts /= 1000.0
    return datetime.fromtimestamp(ts)


# --- Directories ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"

# --- Constants ---
NUM_TIMEFRAMES = 6          # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20         # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7    # Each channel has 7 features.
ORDER_CHANNELS = 1          # One extra channel for order information.
# --- Training Cache Helpers ---
def load_training_cache(filename):
    """Load the persisted training cache (e.g. the running total PnL) from JSON.

    Args:
        filename: Path of the JSON cache file.

    Returns:
        The decoded cache, or ``{"total_pnl": 0.0}`` when the file does not
        exist or cannot be read/decoded (the error is printed, not raised).
    """
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                cache = json.load(f)
            print(f"Loaded training cache from {filename}.")
            return cache
        except Exception as e:
            print("Error loading training cache:", e)
    return {"total_pnl": 0.0}


def save_training_cache(filename, cache):
    """Write ``cache`` to ``filename`` as JSON; errors are printed, not raised."""
    try:
        with open(filename, "w") as f:
            json.dump(cache, f)
    except Exception as e:
        print("Error saving training cache:", e)


# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    The ``pe`` buffer has shape ``(max_len, 1, d_model)`` and is sliced by
    ``x.size(0)``, so ``forward`` expects sequence-first input
    ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Transformer over per-channel feature vectors that predicts a high and a low.

    Each input channel has its own MLP branch projecting
    ``FEATURES_PER_CHANNEL`` features to ``hidden_dim``. A learned per-channel
    embedding is added, the channel sequence runs through a 2-layer Transformer
    encoder, and an attention-pooled summary feeds two regression heads.

    Args:
        num_channels: Number of input channels (the encoder's sequence length).
        num_timeframes: Not used by the model; kept for interface compatibility.
        hidden_dim: Width of the channel projections and of the encoder.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        # One branch per channel.
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )
            for _ in range(num_channels)
        ])
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE(review): constructed (and saved in checkpoints as a buffer) but
        # never applied in forward(); it is also sequence-first while the
        # encoder below is batch_first.
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True,  # avoid nested tensor warning
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Predict (high, low).

        Args:
            x: ``[batch_size, num_channels, FEATURES_PER_CHANNEL]``.
            timeframe_ids: ``[num_channels]`` indices into ``timeframe_embed``.

        Returns:
            ``(high, low)`` tensors, each passed through ``squeeze()`` (a 0-d
            tensor when ``batch_size == 1``).
        """
        batch_size, num_channels, _ = x.shape
        channel_outs = []
        for i in range(num_channels):
            channel_out = self.channel_branches[i](x[:, i, :])
            channel_outs.append(channel_out)
        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
        tf_embeds = self.timeframe_embed(timeframe_ids)  # [num_channels, hidden]
        stacked = stacked + tf_embeds.unsqueeze(0)  # add embeddings (broadcast along batch)
        transformer_out = self.transformer(stacked)
        # Softmax over the channel axis gives per-channel pooling weights.
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
        aggregated = (transformer_out * attn_weights).sum(dim=1)
        return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()


# --- Technical Indicator Helpers ---
def _rolling_mean(candles_list, index, period, key):
    """Mean of ``candle[key]`` over the up-to-``period`` candles ending at ``index``."""
    start = max(0, index - period + 1)
    values = [candle[key] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def compute_sma(candles_list, index, period=10):
    """Simple moving average of "close" over the window ending at ``index``."""
    return _rolling_mean(candles_list, index, period, "close")


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of "volume" over the window ending at ``index``."""
    return _rolling_mean(candles_list, index, period, "volume")


def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` of the last candle whose timestamp is <= ``target_ts``.

    The scan stops at the first later candle, so ``candles_list`` is assumed to
    be sorted by ascending timestamp. If no candle qualifies, index 0 is
    returned. Raises IndexError for an empty list.
    """
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Build the 7-feature vector [open, high, low, close, volume, sma_close, sma_volume]."""
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the cached candles dict from JSON; returns ``{}`` if missing or unreadable."""
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return {}


def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` as JSON; errors are printed, not raised."""
    try:
        with open(filename, "w") as f:
            json.dump(candles_dict, f)
    except Exception as e:
        print("Error saving cache file:", e)


def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest files (by modification time) so at most ``max_files`` remain."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=os.path.getmtime)
        for f in full_paths[:len(files) - max_files]:
            os.remove(f)


def get_best_models(directory):
    """List ``(loss, filename)`` for files named ``best_<loss>_...``.

    Files whose second underscore-separated field is not a float are skipped.
    """
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
            best_files.append((loss, file))
        except Exception:
            continue
    return best_files


def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a rolling "last" checkpoint and, if good enough, a "best" checkpoint.

    The best directory keeps at most 10 files; a new checkpoint replaces the
    worst one only when its ``loss`` is lower (the loss is encoded in the
    filename and read back by ``get_best_models``).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    payload = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(payload, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    best_models = get_best_models(best_dir)
    add_to_best = len(best_models) < 10
    if not add_to_best:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_path = os.path.join(best_dir, f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt")
        torch.save(payload, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the lowest-loss checkpoint from ``best_dir`` into ``model``.

    Tensors are mapped onto the model's current device, so a checkpoint saved
    on GPU can be resumed on a CPU-only machine. If the channel count changed,
    the overlapping rows of ``timeframe_embed`` are copied. This handles both a
    grown and a shrunk table; ``strict=False`` alone does not tolerate size
    mismatches. Missing or extra channel branches are ignored.

    Returns:
        The raw checkpoint dict, or ``None`` if no best checkpoint exists.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    device = next(model.parameters()).device
    checkpoint = torch.load(path, map_location=device)
    old_state = checkpoint["model_state_dict"]
    new_state = model.state_dict()
    if "timeframe_embed.weight" in old_state:
        old_embed = old_state["timeframe_embed.weight"]
        new_embed = new_state["timeframe_embed.weight"].clone()
        rows = min(old_embed.shape[0], new_embed.shape[0])
        new_embed[:rows] = old_embed[:rows]
        old_state["timeframe_embed.weight"] = new_embed
    model.load_state_dict(old_state, strict=False)
    return checkpoint


# --- Function for Manual Trade Override ---
def manual_trade(env):
    """Fallback trade used when the model's signal is too weak.

    Scans the rest of the window (after ``env.current_index``) for the highest
    high and the lowest low. If the high comes first, it records a short trade
    exiting at the low's candle; otherwise it records a long trade exiting at
    the high's candle. Entry/exit prices and PnL use candle "close" prices.
    ``env.current_index`` advances to the exit index (always strictly later).
    """
    window = env.candle_window
    current_index = env.current_index
    if current_index >= len(window) - 1:
        env.current_index = len(window) - 1
        return
    max_val = -float('inf')
    min_val = float('inf')
    i_max = current_index
    i_min = current_index
    for j in range(current_index + 1, len(window)):
        if window[j]["high"] > max_val:
            max_val = window[j]["high"]
            i_max = j
        if window[j]["low"] < min_val:
            min_val = window[j]["low"]
            i_min = j
    entry_price = window[current_index]["close"]
    if i_max < i_min:
        # Peak before trough: short from entry to the trough.
        exit_index = i_min
        exit_price = window[exit_index]["close"]
        reward = entry_price - exit_price
    else:
        # Trough first (or same candle): long from entry to the peak.
        exit_index = i_max
        exit_price = window[exit_index]["close"]
        reward = exit_price - entry_price
    env.trade_history.append({
        "entry_index": current_index,
        "entry_price": entry_price,
        "exit_index": exit_index,
        "exit_price": exit_price,
        "pnl": reward,
    })
    env.current_index = exit_index


# --- Simulation for 1s Data Using Local Extrema ---
def simulate_trades_1s(env):
    """Trade between consecutive local extrema of "close" (used when main_tf is 1s).

    The first and last candles of the remaining window always count as extrema.
    Each consecutive pair becomes one trade in the profitable direction. If
    fewer than two extrema exist, falls back to ``manual_trade``. Afterwards
    ``env.current_index`` is moved to the last candle.
    """
    window = env.candle_window
    n = len(window)
    extrema = []
    for i in range(env.current_index, n):
        if i == env.current_index or i == n - 1:
            extrema.append(i)
        else:
            prev = window[i - 1]["close"]
            curr = window[i]["close"]
            nex = window[i + 1]["close"]
            if curr < prev and curr < nex:
                extrema.append(i)
            elif curr > prev and curr > nex:
                extrema.append(i)
    if len(extrema) < 2:
        manual_trade(env)
        return
    for entry_idx, exit_idx in zip(extrema, extrema[1:]):
        entry_price = window[entry_idx]["close"]
        exit_price = window[exit_idx]["close"]
        if entry_price < exit_price:
            reward = exit_price - entry_price
        else:
            reward = entry_price - exit_price
        env.trade_history.append({
            "entry_index": entry_idx,
            "entry_price": entry_price,
            "exit_index": exit_idx,
            "exit_price": exit_price,
            "pnl": reward,
        })
    env.current_index = n - 1


# --- General Simulation of Trades ---
def simulate_trades(model, env, device, args):
    """Simulate trades over a sliding window.

    For ``main_tf == "1s"``, delegates to ``simulate_trades_1s`` on the current
    window. Otherwise it calls ``env.reset()``, which starts a fresh random
    window and clears the trade history. At each step the model's predicted
    move is compared with ``args.threshold``: a strong signal steps the env
    with BUY (2) or SELL (0); a weak one triggers ``manual_trade``.

    The model runs in eval mode (dropout off) under ``torch.no_grad()`` so
    decisions are deterministic and no autograd graph is built. The caller's
    train/eval mode is restored afterwards.
    """
    if args.main_tf == "1s":
        simulate_trades_1s(env)
        return
    env.reset()
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            while True:
                i = env.current_index
                if i >= len(env.candle_window) - 1:
                    break
                state = env.get_state(i)
                current_open = env.candle_window[i]["open"]
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                timeframe_ids = torch.arange(state.shape[0]).to(device)
                pred_high, pred_low = model(state_tensor, timeframe_ids)
                upside = pred_high.item() - current_open
                downside = current_open - pred_low.item()
                if upside > args.threshold or downside > args.threshold:
                    action = 2 if upside >= downside else 0  # 2 = BUY, 0 = SELL
                    env.step(action)
                else:
                    manual_trade(env)
                if env.current_index >= len(env.candle_window) - 1:
                    break
    finally:
        model.train(was_training)


# --- Live HTML Chart Update (with Volume and Loss) ---
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
    """Render the chart to ``live_chart.html`` as a self-contained, auto-refreshing page.

    The chart shows price, volume, and trade markers (see ``update_live_chart``).
    It is rendered to PNG and embedded inline as base64. A
    ``<meta http-equiv="refresh" content="1">`` tag reloads the page every
    second. The title shows the epoch, loss, the PnL of ``trade_history`` and
    ``total_pnl``.
    """
    from io import BytesIO
    import base64

    fig, ax = plt.subplots(figsize=(12, 6))
    update_live_chart(ax, candles, trade_history)
    epoch_pnl = sum(trade["pnl"] for trade in trade_history)
    title = f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}"
    ax.set_title(title)
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="1">
<title>Live Trading Chart - Epoch {epoch}</title>
</head>
<body>
<h2>{title}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart">
</body>
</html>
"""
    with open("live_chart.html", "w", encoding="utf-8") as f:
        f.write(html_content)
    print("Updated live_chart.html.")


# --- Chart Drawing Helpers (with Volume and Date+Time) ---
def update_live_chart(ax, candles, trade_history):
    """Plot close price and volume on ``ax`` and mark each trade's entry and exit.

    Creates a twin y-axis for volume bars. Entries are green "^" markers and
    exits red "v" markers, joined by a dotted line. Only the first trade's
    markers get legend labels, so the legend does not repeat BUY/SELL once per
    trade.
    """
    ax.clear()
    times = [convert_timestamp(candle["timestamp"]) for candle in candles]
    close_prices = [candle["close"] for candle in candles]
    ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))

    ax2 = ax.twinx()
    volumes = [candle["volume"] for candle in candles]
    if len(times) > 1:
        times_num = mdates.date2num(times)
        bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
    else:
        bar_width = 0.01
    ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
    ax2.set_ylabel("Volume")

    for n, trade in enumerate(trade_history):
        entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
        exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]
        # "_nolegend_" is matplotlib's marker for artists excluded from the legend.
        ax.plot(entry_time, in_price, marker="^", color="green", markersize=10,
                label="BUY" if n == 0 else "_nolegend_")
        ax.plot(exit_time, out_price, marker="v", color="red", markersize=10,
                label="SELL" if n == 0 else "_nolegend_")
        ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")

    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2)
    ax.grid(True)
    ax.get_figure().autofmt_xdate()


# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
    """Random sliding-window backtest over ``candles_dict[base_tf]``.

    Each ``reset()`` picks a random contiguous window of ``window_size``
    candles. ``get_state`` assembles one feature row per timeframe, one order
    row, and ``NUM_INDICATORS`` zero rows.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict  # full candles dict across timeframes
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        if not self.full_candles:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        # Clamp to the available history; a larger window would make reset()
        # call random.randint with a negative upper bound.
        self.window_size = min(window_size, len(self.full_candles))
        self.reset()

    def reset(self):
        """Pick a new random window, clear trades/position, and return the first state."""
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        return self.window_size

    def get_order_features(self, index):
        """Order channel: zeros when flat; else [1.0, relative open-vs-entry diff, 0...]."""
        candle = self.candle_window[index]
        if self.position is None:
            return [0.0] * FEATURES_PER_CHANNEL
        diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
        return [1.0, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """Return a float32 array of shape (len(timeframes) + 1 + NUM_INDICATORS, FEATURES_PER_CHANNEL)."""
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        for tf in self.timeframes:
            if tf == self.base_tf:
                # NOTE(review): the SMA features here are computed over this one
                # candle only (so they equal close/volume), unlike the other
                # timeframes, which use history. Kept for checkpoint compatibility.
                candle = self.candle_window[index]
                features = get_features_for_tf([candle], 0)
            else:
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        state_features.append(self.get_order_features(index))
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Advance one candle, applying ``action`` (2 = open long, 0 = close long).

        A long opens at the next candle's open and closes at the next candle's
        close. Returns ``(state, reward, next_state, done, next_high, next_low)``.
        At the end of the window it returns ``(state, 0.0, None, True, 0.0, 0.0)``.
        """
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base[next_index]
        reward = 0.0
        if self.position is None:
            if action == 2:  # BUY (open long)
                # Entry happens at the next candle's open, so record that candle's index.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:  # SELL (close trade)
            exit_price = next_candle["close"]  # use close price
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None
        self.current_index = next_index
        done = self.current_index >= len(base) - 1
        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]


# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """Train for ``args.epochs - start_epoch`` epochs, one random window per epoch.

    Per step, the loss has three terms:
    - the L1 error of the predicted next high/low;
    - ``lambda_trade`` times the negated largest predicted move;
    - a ``penalty_noaction`` hinge when that move is below ``args.threshold``.

    After each epoch, trades are simulated first. The epoch PnL, the no-trade
    penalty and the checkpoint loss all use those simulated trades. Results are
    then saved, the live HTML chart is refreshed, and the cumulative PnL is
    persisted to the training cache.
    """
    lambda_trade = args.lambda_trade
    # Load any saved total PnL from training cache:
    training_cache = load_training_cache(TRAINING_CACHE_FILE)
    total_pnl = training_cache.get("total_pnl", 0.0)
    for epoch in range(start_epoch, args.epochs):
        model.train()
        env.reset()
        loss_accum = 0.0
        steps = len(env) - 1
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            actual_high = env.candle_window[i + 1]["high"]
            actual_low = env.candle_window[i + 1]["low"]
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            L_pred = (torch.abs(pred_high - torch.tensor(actual_high, device=device))
                      + torch.abs(pred_low - torch.tensor(actual_low, device=device)))
            # Largest predicted move away from the open, in either direction.
            signal_strength = torch.max(pred_high - current_open, current_open - pred_low)
            L_trade = -signal_strength
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            loss = L_pred + lambda_trade * L_trade + penalty_term
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_accum += loss.item()
        scheduler.step()
        epoch_loss = loss_accum / steps if steps > 0 else 0.0

        # Simulate first: training never calls env.step(), so trade_history is
        # empty until trades are simulated.
        simulate_trades(model, env, device, args)
        epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
        # If no trades occurred, penalize the loss (3x its magnitude). Adding
        # 2*|loss| keeps the penalty an increase even when the loss is negative.
        if not env.trade_history:
            epoch_loss += 2 * abs(epoch_loss)
        total_pnl += epoch_pnl
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
        # NOTE(review): this saves the per-step average; older "best" files were
        # named with the per-epoch sum, so their ranking is not comparable.
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        update_live_html(env.candle_window, env.trade_history, epoch + 1, epoch_loss, total_pnl)
        # Update training cache with the new total PnL:
        training_cache["total_pnl"] = total_pnl
        save_training_cache(TRAINING_CACHE_FILE, training_cache)


# --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env):
    """Redraw the live chart every second, forever.

    Reads ``env.candle_window`` on each iteration, because ``env.reset()``
    replaces the window and ``env.trade_history`` indexes into the new one.
    ``candles`` is only a fallback if ``env`` has no window. The figure is
    cleared each time so the twin volume axis is not re-added on every redraw.
    NOTE(review): pyplot is driven from a background thread here; matplotlib
    GUI backends are generally not thread-safe — confirm this works with the
    backend in use.
    """
    plt.ion()
    fig = plt.figure(figsize=(12, 6))
    while True:
        window = getattr(env, "candle_window", candles)
        trades = list(env.trade_history)
        fig.clf()
        ax = fig.add_subplot(111)
        update_live_chart(ax, window, trades)
        plt.draw()
        plt.pause(1)


# --- Argument Parsing ---
def parse_args():
    """Parse command-line options for mode, training hyper-parameters and main timeframe."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--threshold', type=float, default=0.005,
                        help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).")
    parser.add_argument('--lambda_trade', type=float, default=1.0,
                        help="Weight for the trade surrogate loss.")
    parser.add_argument('--penalty_noaction', type=float, default=10.0,
                        help="Penalty if no action is taken (used in loss).")
    parser.add_argument('--start_fresh', action='store_true',
                        help="Start training from scratch.")
    parser.add_argument('--main_tf', type=str, default='1m',
                        help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
    return parser.parse_args()


def random_action():
    """Return a uniformly random action in {0, 1, 2}."""
    return random.randint(0, 2)


# --- Main Function ---
async def main():
    """Load candle data, build the model, and run the selected ``--mode``."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load cached candles.
    candles_dict = load_candles_cache(CACHE_FILE)
    # If the desired main timeframe is 1s, attempt to fetch real 1s data.
    if args.main_tf == "1s":
        real_1s_data = fetch_real_1s_data()
        if real_1s_data:
            candles_dict["1s"] = real_1s_data
    if not candles_dict:
        print("No historical candle data available for backtesting.")
        return

    # Define desired timeframes list; if available, include "1s".
    default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
    timeframes = [tf for tf in default_timeframes if tf in candles_dict]
    if args.main_tf not in timeframes:
        print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}")
        return
    base_tf = args.main_tf  # Set the main timeframe as the base for the environment.

    hidden_dim = 128
    # Total channels: number of timeframes + 1 order channel + NUM_INDICATORS.
    total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS
    model = TradingModel(total_channels, len(timeframes), hidden_dim=hidden_dim).to(device)

    if args.mode == 'train':
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
        start_epoch = 0
        checkpoint = None
        if not args.start_fresh:
            checkpoint = load_best_checkpoint(model)
            if checkpoint is not None:
                start_epoch = checkpoint.get("epoch", 0) + 1
                print(f"Resuming training from epoch {start_epoch}.")
            else:
                print("No checkpoint found. Starting training from scratch.")
        else:
            print("Starting training from scratch as requested.")
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, args.epochs - start_epoch))
        if checkpoint is not None:
            optim_state = checkpoint.get("optimizer_state_dict", None)
            if optim_state is not None and "param_groups" in optim_state:
                try:
                    optimizer.load_state_dict(optim_state)
                    print("Loaded optimizer state from checkpoint.")
                except Exception as e:
                    print("Failed to load optimizer state due to:", e)
                    print("Deleting all checkpoints and starting fresh.")
                    for chk_dir in [LAST_DIR, BEST_DIR]:
                        for f in os.listdir(chk_dir):
                            os.remove(os.path.join(chk_dir, f))
            else:
                print("No valid optimizer state found; using fresh optimizer state.")
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)

    elif args.mode == 'live':
        load_best_checkpoint(model)
        model.eval()  # inference only: disable dropout
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
        print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
        while True:
            if args.main_tf == "1s":
                simulate_trades_1s(env)
            else:
                state = env.get_state(env.current_index)
                current_open = env.candle_window[env.current_index]["open"]
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                timeframe_ids = torch.arange(state.shape[0]).to(device)
                with torch.no_grad():
                    pred_high, pred_low = model(state_tensor, timeframe_ids)
                pred_high = pred_high.item()
                pred_low = pred_low.item()
                if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
                    action = 2 if (pred_high - current_open) >= (current_open - pred_low) else 0
                    env.step(action)
                else:
                    manual_trade(env)
            if env.current_index >= len(env.candle_window) - 1:
                print("Reached end of simulation window; resetting environment.")
                env.reset()
            await asyncio.sleep(1)

    elif args.mode == 'inference':
        load_best_checkpoint(model)
        model.eval()
        print("Running inference...")
        # Inference logic goes here.
    else:
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())