#!/usr/bin/env python3
"""Transformer-based trading experiment.

Fetches OHLCV candles through CCXT (cached on disk as JSON), trains a model that
predicts the next candle's high/low, and runs a simulated trading loop on top of
those predictions.
"""
import sys
import asyncio

if sys.platform == 'win32':
    # NOTE(review): presumably needed by the async HTTP stack used by ccxt on Windows — confirm.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import os
import time
import json
import argparse
import threading
import random
import math
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from dotenv import load_dotenv

load_dotenv()

# Quick visibility into whether a GPU is available.
print(torch.cuda.is_available())

# Define global constants FIRST.
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"


# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    """Convert an epoch timestamp to a naive local ``datetime``.

    Values above 1e10 are treated as milliseconds (the candle cache stores
    millisecond timestamps); smaller values are treated as seconds.
    """
    ts = float(ts)
    if ts > 1e10:
        ts /= 1000.0
    return datetime.fromtimestamp(ts)


# -------------------------------
# Historical Data Fetching Functions (Using CCXT)
# -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """Page through ``exchange.fetch_ohlcv`` from *since* until *end_time* (ms).

    Returns a list of dicts with keys timestamp/open/high/low/close/volume. The
    final batch may include candles after *end_time*. If a request fails, the
    error is printed and the candles gathered so far are returned.
    """
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print("Error fetching historical data:", e)
            break
        if not batch:
            break
        candles.extend(
            {'timestamp': c[0], 'open': c[1], 'high': c[2], 'low': c[3], 'close': c[4], 'volume': c[5]}
            for c in batch
        )
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        if last_timestamp < since_ms:
            # The exchange did not move past the requested start; re-requesting the
            # same 'since' would loop forever.
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles


async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
    """Return candles for *timeframe*, topping up what *cache_file* already holds.

    If the cache has a non-empty list for *timeframe*, only candles newer than its
    last timestamp are fetched and appended (*since* is then ignored). Otherwise the
    whole range [since, end_time] is fetched. The cache file is NOT written here;
    callers persist the result.
    """
    cached_candles = load_candles_cache(cache_file)
    cached = cached_candles.get(timeframe) if cached_candles else None
    if not cached:
        # Missing or empty cache entry: fetch the full requested range.
        return await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)

    last_ts = cached[-1]['timestamp']
    if last_ts < end_time:
        print("Fetching new candles to update cache...")
        new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
        cached.extend(new_candles)
    else:
        print("Cache covers the requested period.")
    return cached


# -------------------------------
# Cache and Training Cache Helpers
# -------------------------------
def _write_json_atomically(filename, obj):
    """Dump *obj* as JSON to a temp file, then move it over *filename* in one step.

    Writing in place would leave a truncated, unparseable file if the process died
    mid-write; ``os.replace`` swaps in the complete file atomically.
    """
    tmp_path = f"{filename}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f)
    os.replace(tmp_path, filename)


def load_candles_cache(filename):
    """Return the candle cache dict stored in *filename*, or ``{}`` if absent/unreadable."""
    if os.path.exists(filename):
        try:
            with open(filename, "r", encoding="utf-8") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return {}


def save_candles_cache(filename, candles_dict):
    """Persist *candles_dict* (timeframe -> candle list) to *filename*; errors are printed."""
    try:
        _write_json_atomically(filename, candles_dict)
    except Exception as e:
        print("Error saving cache file:", e)


def load_training_cache(filename):
    """Return the training cache dict from *filename*, or ``{"total_pnl": 0.0}``."""
    if os.path.exists(filename):
        try:
            with open(filename, "r", encoding="utf-8") as f:
                cache = json.load(f)
            print(f"Loaded training cache from {filename}.")
            return cache
        except Exception as e:
            print("Error loading training cache:", e)
    return {"total_pnl": 0.0}


def save_training_cache(filename, cache):
    """Persist the training cache dict to *filename*; errors are printed."""
    try:
        _write_json_atomically(filename, cache)
    except Exception as e:
        print("Error saving training cache:", e)


# -------------------------------
# Checkpoint Functions
# -------------------------------
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)


def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest files (by modification time) so *directory* keeps at most *max_files*."""
    paths = [os.path.join(directory, name) for name in os.listdir(directory)]
    excess = len(paths) - max_files
    if excess > 0:
        paths.sort(key=os.path.getmtime)
        for path in paths[:excess]:
            os.remove(path)


def get_best_models(directory):
    """Return ``(loss, filename)`` for each file whose 2nd ``_`` field parses as a float.

    Best-checkpoint files are named ``best_<loss>_epoch_<n>_<timestamp>.pt``.
    """
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            best_files.append((float(parts[1]), file))
        except (IndexError, ValueError):
            continue
    return best_files


def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a rolling "last" checkpoint and, if good enough, a "best" checkpoint.

    *last_dir* keeps the 10 most recent checkpoints. *best_dir* keeps the 10
    lowest-loss checkpoints; the current worst one is replaced when *loss* beats it.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt"))
    maintain_checkpoint_directory(last_dir, max_files=10)

    best_models = get_best_models(best_dir)
    add_to_best = len(best_models) < 10
    if not add_to_best:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        torch.save(checkpoint, os.path.join(best_dir, f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"))
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the lowest-loss checkpoint in *best_dir* into *model* (non-strict).

    Tensors whose shape differs from the model's are reconciled:
    ``timeframe_embed.weight`` keeps the overlapping rows, and any other mismatched
    tensor is skipped (the model keeps its current value for it).

    Returns:
        The raw checkpoint dict, or ``None`` when no best checkpoint exists.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine;
    # load_state_dict copies each tensor onto the model's own device afterwards.
    checkpoint = torch.load(path, map_location="cpu")
    old_state = checkpoint["model_state_dict"]
    new_state = model.state_dict()

    compatible = {}
    for key, value in old_state.items():
        target = new_state.get(key)
        if target is None or target.shape == value.shape:
            # Unexpected keys are ignored by strict=False; matching shapes load as-is.
            compatible[key] = value
        elif key == "timeframe_embed.weight" and value.shape[1:] == target.shape[1:]:
            # Channel count changed: keep the rows both versions share.
            merged = target.detach().clone().cpu()
            rows = min(value.shape[0], merged.shape[0])
            merged[:rows] = value[:rows]
            compatible[key] = merged
        else:
            print(f"Skipping '{key}': checkpoint shape {tuple(value.shape)} != model shape {tuple(target.shape)}")
    model.load_state_dict(compatible, strict=False)
    return checkpoint
# -------------------------------
# Positional Encoding and Transformer-Based Model
# -------------------------------
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    The ``pe`` buffer has shape (max_len, 1, d_model) and ``forward`` slices it with
    ``x.size(0)``, so it expects sequence-first input (seq, batch, d_model).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class TradingModel(nn.Module):
    """Per-channel encoders + transformer over channels, predicting next high and low.

    ``x`` is (batch, num_channels, FEATURES_PER_CHANNEL). Each channel goes through
    its own Linear/LayerNorm/GELU/Dropout branch, and a learned embedding indexed by
    ``timeframe_ids`` is added. A 2-layer transformer encoder (batch_first) then
    mixes the channels, and attention pooling reduces them to one vector that feeds
    two regression heads.

    Args:
        num_channels: number of input channels (timeframes + order + indicator rows).
        num_timeframes: currently unused; kept for call compatibility.
        hidden_dim: width of the channel embeddings and the transformer.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )
            for _ in range(num_channels)
        ])
        # One row per channel: callers pass timeframe_ids = arange(number of channels).
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE(review): pos_encoder is not used in forward(); it is kept so the
        # 'pe' buffer keeps matching the keys stored in existing checkpoints.
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=512,
            dropout=0.1, activation='gelu', batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Return ``(pred_high, pred_low)``, each ``.squeeze()``-d (0-d for a batch of one)."""
        num_channels = x.shape[1]
        # (batch, channels, hidden): each channel uses its own branch.
        stacked = torch.stack(
            [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)], dim=1
        )
        # Callers pass a 1-D id per channel, so the embedding broadcasts over the batch.
        stacked = stacked + self.timeframe_embed(timeframe_ids).unsqueeze(0)
        transformer_out = self.transformer(stacked)
        # Softmax over the channel axis, then a weighted sum of channel vectors.
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
        aggregated = (transformer_out * attn_weights).sum(dim=1)
        return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()


# -------------------------------
# Technical Indicator Helpers
# -------------------------------
def _mean_field(candles_list, index, period, field):
    """Mean of ``candle[field]`` over the up-to-*period* candles ending at *index*; 0.0 if none."""
    window = candles_list[max(0, index - period + 1):index + 1]
    return sum(candle[field] for candle in window) / len(window) if window else 0.0


def compute_sma(candles_list, index, period=10):
    """Simple moving average of ``close`` over the *period* candles ending at *index*."""
    return _mean_field(candles_list, index, period, "close")


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of ``volume`` over the *period* candles ending at *index*."""
    return _mean_field(candles_list, index, period, "volume")


def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(i, candle)`` for the last candle with timestamp <= *target_ts*.

    Scans from the start and stops at the first later candle, so *candles_list*
    is assumed to be sorted by timestamp. Falls back to index 0 when every candle
    is later than *target_ts*.
    """
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break
    return best_idx, candles_list[best_idx]


def get_features_for_tf(candles_list, index, period=10):
    """Return ``[open, high, low, close, volume, sma_close, sma_volume]`` for candle *index*."""
    candle = candles_list[index]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]


# -------------------------------
# Backtest Environment Class
# -------------------------------
class BacktestEnvironment:
    """Replays a random fixed-size window of base-timeframe candles for a long-only strategy.

    ``step`` actions: 2 opens a position at the next candle's open when flat; 0
    closes an open position at the next candle's close. Any other action (or one
    that does not apply in the current position state) only advances one candle.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        if not self.full_candles:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        # Never request a window longer than the history, otherwise reset() would
        # call random.randint with a negative upper bound.
        self.window_size = min(window_size, len(self.full_candles))
        self.reset()

    def reset(self):
        """Pick a new random window, clear trades/position and return the first state."""
        # Re-read the list so candles replaced in candles_dict (e.g. by a live updater)
        # are picked up on the next episode.
        self.full_candles = self.candles_dict[self.base_tf]
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        return self.window_size

    def get_order_features(self, index):
        """Order channel: zeros when flat, else ``[1.0, (open - entry) / open, 0, ...]``."""
        if self.position is None:
            return [0.0] * FEATURES_PER_CHANNEL
        open_price = self.candle_window[index]["open"]
        diff = (open_price - self.position["entry_price"]) / open_price
        return [1.0, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """Return the float32 state matrix for window position *index*.

        Rows: one per timeframe in ``self.timeframes``, then the order channel,
        then NUM_INDICATORS all-zero placeholder rows.
        """
        base_ts = self.candle_window[index]["timestamp"]
        rows = []
        for tf in self.timeframes:
            if tf == self.base_tf:
                # Index into the full history so the SMA features see preceding
                # candles rather than a one-candle list (where SMA == close/volume).
                rows.append(get_features_for_tf(self.full_candles, self.start_index + index))
            else:
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                rows.append(get_features_for_tf(self.candles_dict[tf], aligned_idx))
        rows.append(self.get_order_features(index))
        rows.extend([0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS))
        return np.array(rows, dtype=np.float32)

    def step(self, action):
        """Advance one candle, applying *action*.

        Returns ``(current_state, reward, next_state, done, actual_high, actual_low)``.
        ``reward`` is the PnL of a trade closed on this step (else 0.0), and
        ``actual_high``/``actual_low`` come from the next candle. At the end of the
        window it returns ``(state, 0.0, None, True, 0.0, 0.0)`` without advancing.
        """
        window = self.candle_window
        if self.current_index >= len(window) - 1:
            return self.get_state(self.current_index), 0.0, None, True, 0.0, 0.0

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = window[next_index]
        reward = 0.0

        if self.position is None:
            if action == 2:
                # The fill happens at the next candle's open, so record that candle's index.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:
            exit_price = next_candle["close"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = self.current_index >= len(window) - 1
        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]
# -------------------------------
# Enhanced Training Loop
# -------------------------------
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """Train *model* on random windows of *env* for epochs ``start_epoch .. args.epochs - 1``.

    Each epoch:
      1. reset *env* to a new random window and take one optimizer step per candle.
         The loss is the L1 error on the next candle's high/low, minus
         ``args.lambda_trade`` times the largest predicted move away from the open,
         plus ``args.penalty_noaction`` times how far that move falls short of
         ``args.threshold``;
      2. step the LR scheduler once (it is sized in epochs);
      3. run ``simulate_trades`` so ``env.trade_history`` holds this epoch's trades,
         then report, checkpoint and plot, and persist the running total PnL in
         TRAINING_CACHE_FILE.

    The reported/checkpointed epoch loss is tripled when the simulation made no trades.
    """
    lambda_trade = args.lambda_trade
    training_cache = load_training_cache(TRAINING_CACHE_FILE)
    total_pnl = training_cache.get("total_pnl", 0.0)

    for epoch in range(start_epoch, args.epochs):
        # simulate_trades may leave the model in eval mode; re-enable dropout for training.
        model.train()
        env.reset()
        loss_accum = 0.0
        steps = len(env) - 1
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            next_candle = env.candle_window[i + 1]

            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0], device=device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)

            actual_high = torch.tensor(next_candle["high"], device=device)
            actual_low = torch.tensor(next_candle["low"], device=device)
            pred_loss = torch.abs(pred_high - actual_high) + torch.abs(pred_low - actual_low)

            # Largest predicted move away from the open, in either direction.
            signal_strength = torch.max(pred_high - current_open, current_open - pred_low)
            trade_loss = -signal_strength
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            loss = pred_loss + lambda_trade * trade_loss + penalty_term

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_accum += loss.item()
        scheduler.step()

        epoch_loss = loss_accum / steps if steps > 0 else 0.0

        # The optimisation loop never calls env.step(), so trades only exist after
        # the simulation; PnL and the no-trade penalty must be computed afterwards.
        simulate_trades(model, env, device, args)
        if not env.trade_history:
            epoch_loss *= 3
        epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
        total_pnl += epoch_pnl

        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
        # Rank checkpoints by the same per-step loss that is reported above.
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        update_live_html(env.candle_window, env.trade_history, epoch + 1, epoch_loss, total_pnl)
        training_cache["total_pnl"] = total_pnl
        save_training_cache(TRAINING_CACHE_FILE, training_cache)
------------------------------- # Live Plotting (for Live Mode) # ------------------------------- def live_preview_loop(candles, env): plt.ion() fig, ax = plt.subplots(figsize=(12, 6)) while True: update_live_chart(ax, candles, env.trade_history) plt.draw() plt.pause(1) # ------------------------------- # Live HTML Chart Update (with Volume and Loss) # ------------------------------- def update_live_html(candles, trade_history, epoch, loss, total_pnl): from io import BytesIO import base64 fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history) ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}") buf = BytesIO() fig.savefig(buf, format='png') plt.close(fig) buf.seek(0) image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') html_content = f""" Live Trading Chart - Epoch {epoch}

Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}

Live Chart
""" with open("live_chart.html", "w") as f: f.write(html_content) print("Updated live_chart.html.") # ------------------------------- # Chart Drawing Helpers (with Volume and Date+Time) # ------------------------------- def update_live_chart(ax, candles, trade_history): ax.clear() times = [convert_timestamp(candle["timestamp"]) for candle in candles] close_prices = [candle["close"] for candle in candles] ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.set_xlabel("Time") ax.set_ylabel("Price") ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) ax2 = ax.twinx() volumes = [candle["volume"] for candle in candles] if len(times) > 1: times_num = mdates.date2num(times) bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 else: bar_width = 0.01 ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") ax2.set_ylabel("Volume") for trade in trade_history: entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) in_price = trade["entry_price"] out_price = trade["exit_price"] ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") lines, labels = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax.legend(lines + lines2, labels + labels2) ax.grid(True) fig = ax.get_figure() fig.autofmt_xdate() # ------------------------------- # Backtest Environment Class # ------------------------------- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes, window_size=None): self.candles_dict = candles_dict self.base_tf = base_tf self.timeframes = timeframes self.full_candles = candles_dict[base_tf] if window_size is None: window_size = 100 if len(self.full_candles) 
>= 100 else len(self.full_candles) self.window_size = window_size self.reset() def reset(self): self.start_index = random.randint(0, len(self.full_candles)-self.window_size) self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size] self.current_index = 0 self.trade_history = [] self.position = None return self.get_state(self.current_index) def __len__(self): return self.window_size def get_order_features(self, index): candle = self.candle_window[index] if self.position is None: return [0.0] * FEATURES_PER_CHANNEL else: flag = 1.0 diff = (candle["open"] - self.position["entry_price"]) / candle["open"] return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) def get_state(self, index): state_features = [] base_ts = self.candle_window[index]["timestamp"] for tf in self.timeframes: if tf == self.base_tf: candle = self.candle_window[index] features = get_features_for_tf([candle], 0) else: aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) features = get_features_for_tf(self.candles_dict[tf], aligned_idx) state_features.append(features) order_features = self.get_order_features(index) state_features.append(order_features) for _ in range(NUM_INDICATORS): state_features.append([0.0]*FEATURES_PER_CHANNEL) return np.array(state_features, dtype=np.float32) def step(self, action): base = self.candle_window if self.current_index >= len(base)-1: current_state = self.get_state(self.current_index) return current_state, 0.0, None, True, 0.0, 0.0 current_state = self.get_state(self.current_index) next_index = self.current_index + 1 next_state = self.get_state(next_index) next_candle = base[next_index] reward = 0.0 if self.position is None: if action == 2: # BUY self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: if action == 0: # SELL exit_price = next_candle["close"] reward = exit_price - self.position["entry_price"] trade = { "entry_index": self.position["entry_index"], 
"entry_price": self.position["entry_price"], "exit_index": next_index, "exit_price": exit_price, "pnl": reward } self.trade_history.append(trade) self.position = None self.current_index = next_index done = (self.current_index >= len(base)-1) actual_high = next_candle["high"] actual_low = next_candle["low"] return current_state, reward, next_state, done, actual_high, actual_low # ------------------------------- # Enhanced Training Loop # ------------------------------- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): lambda_trade = args.lambda_trade training_cache = load_training_cache(TRAINING_CACHE_FILE) total_pnl = training_cache.get("total_pnl", 0.0) for epoch in range(start_epoch, args.epochs): env.reset() loss_accum = 0.0 steps = len(env) - 1 for i in range(steps): state = env.get_state(i) current_open = env.candle_window[i]["open"] actual_high = env.candle_window[i+1]["high"] actual_low = env.candle_window[i+1]["low"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ torch.abs(pred_low - torch.tensor(actual_low, device=device)) profit_buy = pred_high - current_open profit_sell = current_open - pred_low L_trade = - torch.max(profit_buy, profit_sell) current_open_tensor = torch.tensor(current_open, device=device) signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low) penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0) loss = L_pred + lambda_trade * L_trade + penalty_term optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() loss_accum += loss.item() scheduler.step() epoch_loss = loss_accum / steps if len(env.trade_history) == 0: epoch_loss *= 3 epoch_pnl = sum(trade["pnl"] for trade in 
env.trade_history) total_pnl += epoch_pnl print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}") save_checkpoint(model, optimizer, epoch, loss_accum) simulate_trades(model, env, device, args) update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl) training_cache["total_pnl"] = total_pnl save_training_cache(TRAINING_CACHE_FILE, training_cache) # ------------------------------- # Live Plotting (for Live Mode) # ------------------------------- def live_preview_loop(candles, env): plt.ion() fig, ax = plt.subplots(figsize=(12, 6)) while True: update_live_chart(ax, candles, env.trade_history) plt.draw() plt.pause(1) # ------------------------------- # Live HTML Chart Update (with Volume and Loss) # ------------------------------- def update_live_html(candles, trade_history, epoch, loss, total_pnl): from io import BytesIO import base64 fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history) ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}") buf = BytesIO() fig.savefig(buf, format='png') plt.close(fig) buf.seek(0) image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') html_content = f""" Live Trading Chart - Epoch {epoch}

Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}

Live Chart
""" with open("live_chart.html", "w") as f: f.write(html_content) print("Updated live_chart.html.") # ------------------------------- # Chart Drawing Helpers (with Volume and Date+Time) # ------------------------------- def update_live_chart(ax, candles, trade_history): ax.clear() times = [convert_timestamp(candle["timestamp"]) for candle in candles] close_prices = [candle["close"] for candle in candles] ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.set_xlabel("Time") ax.set_ylabel("Price") ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) ax2 = ax.twinx() volumes = [candle["volume"] for candle in candles] if len(times) > 1: times_num = mdates.date2num(times) bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 else: bar_width = 0.01 ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") ax2.set_ylabel("Volume") for trade in trade_history: entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) in_price = trade["entry_price"] out_price = trade["exit_price"] ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") lines, labels = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax.legend(lines + lines2, labels + labels2) ax.grid(True) fig = ax.get_figure() fig.autofmt_xdate() # ------------------------------- # Global Constants for Features # ------------------------------- NUM_INDICATORS = 20 FEATURES_PER_CHANNEL = 7 ORDER_CHANNELS = 1 # ------------------------------- # General Simulation of Trades Function # ------------------------------- def simulate_trades(model, env, device, args): if args.main_tf == "1s": simulate_trades_1s(env) return env.reset() while 
True: i = env.current_index if i >= len(env.candle_window) - 1: break state = env.get_state(i) current_open = env.candle_window[i]["open"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high = pred_high.item() pred_low = pred_low.item() if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: if (pred_high - current_open) >= (current_open - pred_low): action = 2 else: action = 0 _, _, _, done, _, _ = env.step(action) else: manual_trade(env) if env.current_index >= len(env.candle_window) - 1: break # ------------------------------- # Argument Parsing # ------------------------------- def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--epochs', type=int, default=1000) parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.") parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--main_tf', type=str, default='1m', help="Desired main timeframe to focus on (e.g., '1s' or '1m').") # Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior. 
parser.add_argument('--no-fetch', dest='fetch', action='store_false', help="Do NOT fetch fresh data from exchange on start.") parser.set_defaults(fetch=True) parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.") return parser.parse_args() def random_action(): return random.randint(0, 2) # ------------------------------- # Main Function # ------------------------------- async def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) # With fetch defaulting to True, live mode will always try to top-up the cache. if args.fetch: import ccxt.async_support as ccxt exchange = ccxt.binance({'enableRateLimit': True}) now_ms = int(time.time()*1000) cached = load_candles_cache(CACHE_FILE) if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0: last_ts = cached[args.main_tf][-1]['timestamp'] since = last_ts + 1 else: since = now_ms - 2*24*60*60*1000 print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...") fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms) candles_dict = {args.main_tf: fresh_candles} save_candles_cache(CACHE_FILE, candles_dict) await exchange.close() else: candles_dict = load_candles_cache(CACHE_FILE) if not candles_dict: print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.") return default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"] timeframes = [tf for tf in default_timeframes if tf in candles_dict] if args.main_tf not in timeframes: print(f"Desired main timeframe {args.main_tf} is not available. 
Available: {timeframes}") return base_tf = args.main_tf hidden_dim = 128 total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS model = TradingModel(total_channels, len(timeframes)).to(device) if args.mode == 'train': env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) start_epoch = 0 checkpoint = None if not args.start_fresh: checkpoint = load_best_checkpoint(model) if checkpoint is not None: start_epoch = checkpoint.get("epoch", 0) + 1 print(f"Resuming training from epoch {start_epoch}.") else: print("No checkpoint found. Starting training from scratch.") else: print("Starting training from scratch as requested.") optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch) if checkpoint is not None: optim_state = checkpoint.get("optimizer_state_dict", None) if optim_state is not None and "param_groups" in optim_state: try: optimizer.load_state_dict(optim_state) print("Loaded optimizer state from checkpoint.") except Exception as e: print("Failed to load optimizer state due to:", e) print("Deleting all checkpoints and starting fresh.") for chk_dir in [LAST_DIR, BEST_DIR]: for f in os.listdir(chk_dir): os.remove(os.path.join(chk_dir, f)) else: print("No valid optimizer state found; using fresh optimizer state.") train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) elif args.mode == 'live': import ccxt.async_support as ccxt exchange = ccxt.binance({'enableRateLimit': True}) POLL_INTERVAL = 60 # seconds async def update_live_candles(): nonlocal exchange, args, candles_dict while True: now_ms = int(time.time()*1000) new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since=now_ms - 2*60*1000, end_time=now_ms) if args.main_tf in candles_dict: candles_dict[args.main_tf] = new_candles else: candles_dict[args.main_tf] = new_candles print("Live candles 
updated.") await asyncio.sleep(POLL_INTERVAL) asyncio.create_task(update_live_candles()) load_best_checkpoint(model) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread.start() print("Starting live trading loop. (Using live updated data now.)") while True: if args.main_tf == "1s": simulate_trades_1s(env) else: state = env.get_state(env.current_index) current_open = env.candle_window[env.current_index]["open"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high = pred_high.item() pred_low = pred_low.item() if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: if (pred_high - current_open) >= (current_open - pred_low): action = 2 else: action = 0 _, _, _, done, _, _ = env.step(action) else: manual_trade(env) if env.current_index >= len(env.candle_window) - 1: print("Reached end of simulation window; resetting environment.") env.reset() await asyncio.sleep(1) elif args.mode == 'inference': load_best_checkpoint(model) print("Running inference...") else: print("Invalid mode specified.") if __name__ == "__main__": asyncio.run(main())