From 0e7997d50ae090b35f4ebcb562b19758907a61d3 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 5 Feb 2025 11:32:21 +0200 Subject: [PATCH] live data and refactoring --- crypto/brian/index-deep-new.py | 954 ++++++++++++++++++++------------- 1 file changed, 571 insertions(+), 383 deletions(-) diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 8a2c8a9..5d085ff 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -22,24 +22,11 @@ import matplotlib.dates as mdates from dotenv import load_dotenv load_dotenv() -# --- Fetch Real 1s Data (if main_tf=="1s") --- -def fetch_real_1s_data(): - """ - Fetch real 1-second candle data from your API. - Replace the URL and parameters with those required by your data provider. - Expected data format: a list of dictionaries with keys: "timestamp", "open", "high", "low", "close", "volume" - """ - import requests - url = "https://api.example.com/1s-data" # <-- Replace with your actual endpoint. - try: - response = requests.get(url) - response.raise_for_status() - data = response.json() - print("Fetched real 1s data successfully.") - return data - except Exception as e: - print("Failed to fetch real 1s data:", e) - return [] + +# Define global constants FIRST. +CACHE_FILE = "candles_cache.json" +TRAINING_CACHE_FILE = "training_cache.json" + # --- Helper Function for Timestamp Conversion --- def convert_timestamp(ts): @@ -53,21 +40,78 @@ def convert_timestamp(ts): ts /= 1000.0 return datetime.fromtimestamp(ts) -# --- Directories --- -LAST_DIR = os.path.join("models", "last") -BEST_DIR = os.path.join("models", "best") -os.makedirs(LAST_DIR, exist_ok=True) -os.makedirs(BEST_DIR, exist_ok=True) -CACHE_FILE = "candles_cache.json" -TRAINING_CACHE_FILE = "training_cache.json" +# ------------------------------- +# Historical Data Fetching Functions (Using CCXT) +# ------------------------------- +async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500): + """ + Fetch historical OHLCV data for the given symbol and timeframe. + "since" and "end_time" are in milliseconds. + """ + candles = [] + since_ms = since + while True: + try: + batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size) + except Exception as e: + print("Error fetching historical data:", e) + break + if not batch: + break + for c in batch: + candle_dict = { + 'timestamp': c[0], + 'open': c[1], + 'high': c[2], + 'low': c[3], + 'close': c[4], + 'volume': c[5] + } + candles.append(candle_dict) + last_timestamp = batch[-1][0] + if last_timestamp >= end_time: + break + since_ms = last_timestamp + 1 + print(f"Fetched {len(candles)} candles for timeframe {timeframe}.") + return candles -# --- Constants --- -NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"] -NUM_INDICATORS = 20 # e.g., 20 technical indicators -FEATURES_PER_CHANNEL = 7 # Each channel has 7 features. -ORDER_CHANNELS = 1 # One extra channel for order information. 
+async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500): + cached_candles = load_candles_cache(cache_file) + if cached_candles and timeframe in cached_candles: + last_ts = cached_candles[timeframe][-1]['timestamp'] + if last_ts < end_time: + print("Fetching new candles to update cache...") + new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size) + cached_candles[timeframe].extend(new_candles) + else: + print("Cache covers the requested period.") + return cached_candles[timeframe] + else: + candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size) + return candles + + +# ------------------------------- +# Cache and Training Cache Helpers +# ------------------------------- +def load_candles_cache(filename): + if os.path.exists(filename): + try: + with open(filename, "r") as f: + data = json.load(f) + print(f"Loaded cached data from {filename}.") + return data + except Exception as e: + print("Error reading cache file:", e) + return {} + +def save_candles_cache(filename, candles_dict): + try: + with open(filename, "w") as f: + json.dump(candles_dict, f) + except Exception as e: + print("Error saving cache file:", e) -# --- Training Cache Helpers --- def load_training_cache(filename): if os.path.exists(filename): try: @@ -86,116 +130,15 @@ def save_training_cache(filename, cache): except Exception as e: print("Error saving training cache:", e) -# --- Positional Encoding Module --- -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - def forward(self, x): - x = x + self.pe[:x.size(0)] - return self.dropout(x) +TRAINING_CACHE_FILE = "training_cache.json" -# --- Enhanced Transformer Model --- -class TradingModel(nn.Module): - def __init__(self, num_channels, num_timeframes, hidden_dim=128): - super().__init__() - # One branch per channel. 
- self.channel_branches = nn.ModuleList([ - nn.Sequential( - nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - nn.Dropout(0.1) - ) for _ in range(num_channels) - ]) - self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) - self.pos_encoder = PositionalEncoding(hidden_dim) - encoder_layers = TransformerEncoderLayer( - d_model=hidden_dim, nhead=4, dim_feedforward=512, - dropout=0.1, activation='gelu', batch_first=True # avoid nested tensor warning - ) - self.transformer = TransformerEncoder(encoder_layers, num_layers=2) - self.attn_pool = nn.Linear(hidden_dim, 1) - self.high_pred = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.GELU(), - nn.Linear(hidden_dim // 2, 1) - ) - self.low_pred = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.GELU(), - nn.Linear(hidden_dim // 2, 1) - ) - def forward(self, x, timeframe_ids): - # x: [batch_size, num_channels, FEATURES_PER_CHANNEL] - batch_size, num_channels, _ = x.shape - channel_outs = [] - for i in range(num_channels): - channel_out = self.channel_branches[i](x[:, i, :]) - channel_outs.append(channel_out) - stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] - tf_embeds = self.timeframe_embed(timeframe_ids) # [num_channels, hidden] - stacked = stacked + tf_embeds.unsqueeze(0) # add embeddings (broadcast along batch) - transformer_out = self.transformer(stacked) - attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) - aggregated = (transformer_out * attn_weights).sum(dim=1) - return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze() - -# --- Technical Indicator Helpers --- -def compute_sma(candles_list, index, period=10): - start = max(0, index - period + 1) - values = [candle["close"] for candle in candles_list[start:index+1]] - return sum(values)/len(values) if values else 0.0 - -def compute_sma_volume(candles_list, index, period=10): - start = max(0, index - period + 1) - values = [candle["volume"] for candle in candles_list[start:index+1]] - return sum(values)/len(values) if values else 0.0 - -def get_aligned_candle_with_index(candles_list, target_ts): - best_idx = 0 - for i, candle in enumerate(candles_list): - if candle["timestamp"] <= target_ts: - best_idx = i - else: - break - return best_idx, candles_list[best_idx] - -def get_features_for_tf(candles_list, index, period=10): - candle = candles_list[index] - f_open = candle["open"] - f_high = candle["high"] - f_low = candle["low"] - f_close = candle["close"] - f_volume = candle["volume"] - sma_close = compute_sma(candles_list, index, period) - sma_volume= compute_sma_volume(candles_list, index, period) - return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume] - -# --- Caching & Checkpoint Functions --- -def load_candles_cache(filename): - if os.path.exists(filename): - try: - with open(filename, "r") as f: - data = json.load(f) - print(f"Loaded cached data from {filename}.") - return data - except Exception as e: - print("Error reading cache file:", e) - return {} - -def save_candles_cache(filename, candles_dict): - try: - with open(filename, "w") as f: - json.dump(candles_dict, f) - except Exception as e: - print("Error saving cache file:", e) +# ------------------------------- +# Checkpoint Functions +# ------------------------------- +LAST_DIR = os.path.join("models", "last") +BEST_DIR = os.path.join("models", "best") +os.makedirs(LAST_DIR, exist_ok=True) +os.makedirs(BEST_DIR, exist_ok=True) def maintain_checkpoint_directory(directory, 
max_files=10): files = os.listdir(directory) @@ -267,231 +210,103 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): model.load_state_dict(old_state, strict=False) return checkpoint -# --- Function for Manual Trade Override --- -def manual_trade(env): - """ - When no sufficient action is taken by the model, use a fallback: - Scan the remaining window for the global maximum and minimum. - If the maximum occurs before the minimum, simulate a short trade; - otherwise simulate a long trade. - Use the candle "close" prices to compute trade reward. - """ - current_index = env.current_index - if current_index >= len(env.candle_window) - 1: - env.current_index = len(env.candle_window) - 1 - return - max_val = -float('inf') - min_val = float('inf') - i_max = current_index - i_min = current_index - for j in range(current_index + 1, len(env.candle_window)): - high_j = env.candle_window[j]["high"] - low_j = env.candle_window[j]["low"] - if high_j > max_val: - max_val = high_j - i_max = j - if low_j < min_val: - min_val = low_j - i_min = j - if i_max < i_min: - entry_price = env.candle_window[current_index]["close"] - exit_price = env.candle_window[i_min]["close"] - reward = entry_price - exit_price - trade = { - "entry_index": current_index, - "entry_price": entry_price, - "exit_index": i_min, - "exit_price": exit_price, - "pnl": reward - } - else: - entry_price = env.candle_window[current_index]["close"] - exit_price = env.candle_window[i_max]["close"] - reward = exit_price - entry_price - trade = { - "entry_index": current_index, - "entry_price": entry_price, - "exit_index": i_max, - "exit_price": exit_price, - "pnl": reward - } - env.trade_history.append(trade) - env.current_index = trade["exit_index"] +# ------------------------------- +# Positional Encoding and Transformer-Based Model +# ------------------------------- +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + def forward(self, x): + x = x + self.pe[:x.size(0)] + return self.dropout(x) -# --- Simulation for 1s Data Using Local Extrema --- -def simulate_trades_1s(env): - """ - When the main timeframe is 1s, scan the entire remaining window to detect local extrema. - If at least two extrema are found, pair consecutive extrema as trades. - Use the candle "close" prices for trade reward calculation. - If too few extrema are found, fallback to manual_trade. 
- """ - n = len(env.candle_window) - extrema = [] - for i in range(env.current_index, n): - if i == env.current_index or i == n-1: - extrema.append(i) - else: - prev = env.candle_window[i-1]["close"] - curr = env.candle_window[i]["close"] - nex = env.candle_window[i+1]["close"] - if curr < prev and curr < nex: - extrema.append(i) - elif curr > prev and curr > nex: - extrema.append(i) - if len(extrema) < 2: - manual_trade(env) - return - for j in range(len(extrema)-1): - entry_idx = extrema[j] - exit_idx = extrema[j+1] - entry_price = env.candle_window[entry_idx]["close"] - exit_price = env.candle_window[exit_idx]["close"] - if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]: - reward = exit_price - entry_price - else: - reward = entry_price - exit_price - trade = { - "entry_index": entry_idx, - "entry_price": entry_price, - "exit_index": exit_idx, - "exit_price": exit_price, - "pnl": reward - } - env.trade_history.append(trade) - env.current_index = n - 1 +class TradingModel(nn.Module): + def __init__(self, num_channels, num_timeframes, hidden_dim=128): + super().__init__() + # One branch per channel. + self.channel_branches = nn.ModuleList([ + nn.Sequential( + nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + nn.Dropout(0.1) + ) for _ in range(num_channels) + ]) + self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) + self.pos_encoder = PositionalEncoding(hidden_dim) + encoder_layers = TransformerEncoderLayer( + d_model=hidden_dim, nhead=4, dim_feedforward=512, + dropout=0.1, activation='gelu', batch_first=True + ) + self.transformer = TransformerEncoder(encoder_layers, num_layers=2) + self.attn_pool = nn.Linear(hidden_dim, 1) + self.high_pred = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.GELU(), + nn.Linear(hidden_dim // 2, 1) + ) + self.low_pred = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.GELU(), + nn.Linear(hidden_dim // 2, 1) + ) + def forward(self, x, timeframe_ids): + batch_size, num_channels, _ = x.shape + channel_outs = [] + for i in range(num_channels): + channel_out = self.channel_branches[i](x[:, i, :]) + channel_outs.append(channel_out) + stacked = torch.stack(channel_outs, dim=1) + tf_embeds = self.timeframe_embed(timeframe_ids) + stacked = stacked + tf_embeds.unsqueeze(0) + transformer_out = self.transformer(stacked) + attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) + aggregated = (transformer_out * attn_weights).sum(dim=1) + return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze() -# --- General Simulation of Trades --- -def simulate_trades(model, env, device, args): - """ - Simulate trades over the current sliding window. - If the main timeframe is 1s, use local extrema detection. - Otherwise, check if the model's predicted potentials exceed the threshold. - Use manual_trade if the model's signal is too weak. 
- """ - if args.main_tf == "1s": - simulate_trades_1s(env) - return - env.reset() - while True: - i = env.current_index - if i >= len(env.candle_window) - 1: - break - state = env.get_state(i) - current_open = env.candle_window[i]["open"] - state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) - timeframe_ids = torch.arange(state.shape[0]).to(device) - pred_high, pred_low = model(state_tensor, timeframe_ids) - pred_high = pred_high.item() - pred_low = pred_low.item() - if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: - if (pred_high - current_open) >= (current_open - pred_low): - action = 2 # BUY - else: - action = 0 # SELL - _, _, _, done, _, _ = env.step(action) - else: - manual_trade(env) - if env.current_index >= len(env.candle_window) - 1: +# ------------------------------- +# Technical Indicator Helpers +# ------------------------------- +def compute_sma(candles_list, index, period=10): + start = max(0, index - period + 1) + values = [candle["close"] for candle in candles_list[start:index+1]] + return sum(values)/len(values) if values else 0.0 + +def compute_sma_volume(candles_list, index, period=10): + start = max(0, index - period + 1) + values = [candle["volume"] for candle in candles_list[start:index+1]] + return sum(values)/len(values) if values else 0.0 + +def get_aligned_candle_with_index(candles_list, target_ts): + best_idx = 0 + for i, candle in enumerate(candles_list): + if candle["timestamp"] <= target_ts: + best_idx = i + else: break + return best_idx, candles_list[best_idx] -# --- Live HTML Chart Update (with Volume and Loss) --- -def update_live_html(candles, trade_history, epoch, loss, total_pnl): - """ - Generate an HTML page with a live chart. - The chart displays price (line) and volume (bars on a secondary y-axis), - and includes trade markers with dotted connecting lines. - The title shows the epoch, loss, and total PnL. - The page auto-refreshes every 1 second. - """ - from io import BytesIO - import base64 - fig, ax = plt.subplots(figsize=(12, 6)) - update_live_chart(ax, candles, trade_history) - epoch_pnl = sum(trade["pnl"] for trade in trade_history) - - ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}") - buf = BytesIO() - fig.savefig(buf, format='png') - plt.close(fig) - buf.seek(0) - image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') - html_content = f""" - - - - - - Live Trading Chart - Epoch {epoch} - - - -
-            <h3>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}</h3>
-            <img src="data:image/png;base64,{image_base64}" alt="Live Chart">
- - -""" - with open("live_chart.html", "w") as f: - f.write(html_content) - print("Updated live_chart.html.") +def get_features_for_tf(candles_list, index, period=10): + candle = candles_list[index] + f_open = candle["open"] + f_high = candle["high"] + f_low = candle["low"] + f_close = candle["close"] + f_volume = candle["volume"] + sma_close = compute_sma(candles_list, index, period) + sma_volume= compute_sma_volume(candles_list, index, period) + return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume] -# --- Chart Drawing Helpers (with Volume and Date+Time) --- -def update_live_chart(ax, candles, trade_history): - """ - Plot the price chart with actual timestamps (date and time in short format) - and volume on a secondary y-axis. - Mark trade entry (green) and exit (red) points, with dotted lines connecting them. - """ - ax.clear() - times = [convert_timestamp(candle["timestamp"]) for candle in candles] - close_prices = [candle["close"] for candle in candles] - ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) - ax.set_xlabel("Time") - ax.set_ylabel("Price") - ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) - ax2 = ax.twinx() - volumes = [candle["volume"] for candle in candles] - if len(times) > 1: - times_num = mdates.date2num(times) - bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 - else: - bar_width = 0.01 - ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") - ax2.set_ylabel("Volume") - for trade in trade_history: - entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) - exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) - in_price = trade["entry_price"] - out_price = trade["exit_price"] - ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") - ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") - ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") - lines, labels = ax.get_legend_handles_labels() - lines2, labels2 = ax2.get_legend_handles_labels() - ax.legend(lines + lines2, labels + labels2) - ax.grid(True) - fig = ax.get_figure() - fig.autofmt_xdate() - -# --- Backtest Environment with Sliding Window and Order Info --- +# ------------------------------- +# Backtest Environment Class +# ------------------------------- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes, window_size=None): self.candles_dict = candles_dict # full candles dict across timeframes @@ -504,7 +319,7 @@ class BacktestEnvironment: self.reset() def reset(self): self.start_index = random.randint(0, len(self.full_candles) - self.window_size) - self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] + self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size] self.current_index = 0 self.trade_history = [] self.position = None @@ -550,7 +365,7 @@ class BacktestEnvironment: self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: if action == 0: # SELL (close trade) - exit_price = next_candle["close"] # use close price + exit_price = next_candle["close"] reward = exit_price - self.position["entry_price"] trade = { "entry_index": self.position["entry_index"], @@ -567,10 +382,11 @@ class BacktestEnvironment: actual_low = next_candle["low"] return current_state, reward, next_state, done, actual_high, actual_low -# --- Enhanced Training 
Loop --- +# ------------------------------- +# Enhanced Training Loop +# ------------------------------- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): lambda_trade = args.lambda_trade - # Load any saved total PnL from training cache: training_cache = load_training_cache(TRAINING_CACHE_FILE) total_pnl = training_cache.get("total_pnl", 0.0) for epoch in range(start_epoch, args.epochs): @@ -601,7 +417,6 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s loss_accum += loss.item() scheduler.step() epoch_loss = loss_accum / steps - # If no trades occurred during the epoch, multiply the loss by 3. if len(env.trade_history) == 0: epoch_loss *= 3 epoch_pnl = sum(trade["pnl"] for trade in env.trade_history) @@ -610,11 +425,12 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s save_checkpoint(model, optimizer, epoch, loss_accum) simulate_trades(model, env, device, args) update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl) - # Update training cache with the new total PnL: training_cache["total_pnl"] = total_pnl save_training_cache(TRAINING_CACHE_FILE, training_cache) -# --- Live Plotting (for Live Mode) --- +# ------------------------------- +# Live Plotting (for Live Mode) +# ------------------------------- def live_preview_loop(candles, env): plt.ion() fig, ax = plt.subplots(figsize=(12, 6)) @@ -623,7 +439,360 @@ def live_preview_loop(candles, env): plt.draw() plt.pause(1) -# --- Argument Parsing --- +# ------------------------------- +# Live HTML Chart Update (with Volume and Loss) +# ------------------------------- +def update_live_html(candles, trade_history, epoch, loss, total_pnl): + from io import BytesIO + import base64 + fig, ax = plt.subplots(figsize=(12, 6)) + update_live_chart(ax, candles, trade_history) + epoch_pnl = sum(trade["pnl"] for trade in trade_history) + ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}") + buf = BytesIO() + fig.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') + html_content = f""" + + + + + + Live Trading Chart - Epoch {epoch} + + + +
+            <h3>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h3>
+            <img src="data:image/png;base64,{image_base64}" alt="Live Chart">
+ + +""" + with open("live_chart.html", "w") as f: + f.write(html_content) + print("Updated live_chart.html.") + +# ------------------------------- +# Chart Drawing Helpers (with Volume and Date+Time) +# ------------------------------- +def update_live_chart(ax, candles, trade_history): + ax.clear() + times = [convert_timestamp(candle["timestamp"]) for candle in candles] + close_prices = [candle["close"] for candle in candles] + ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) + ax.set_xlabel("Time") + ax.set_ylabel("Price") + ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) + ax2 = ax.twinx() + volumes = [candle["volume"] for candle in candles] + if len(times) > 1: + times_num = mdates.date2num(times) + bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 + else: + bar_width = 0.01 + ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") + ax2.set_ylabel("Volume") + for trade in trade_history: + entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) + exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) + in_price = trade["entry_price"] + out_price = trade["exit_price"] + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") + ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax.legend(lines + lines2, labels + labels2) + ax.grid(True) + fig = ax.get_figure() + fig.autofmt_xdate() + +# ------------------------------- +# Backtest Environment Class +# ------------------------------- +class BacktestEnvironment: + def __init__(self, candles_dict, base_tf, timeframes, window_size=None): + self.candles_dict = candles_dict + self.base_tf = base_tf + self.timeframes = timeframes + self.full_candles = candles_dict[base_tf] + if window_size is None: + window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles) + self.window_size = window_size + self.reset() + def reset(self): + self.start_index = random.randint(0, len(self.full_candles)-self.window_size) + self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size] + self.current_index = 0 + self.trade_history = [] + self.position = None + return self.get_state(self.current_index) + def __len__(self): + return self.window_size + def get_order_features(self, index): + candle = self.candle_window[index] + if self.position is None: + return [0.0] * FEATURES_PER_CHANNEL + else: + flag = 1.0 + diff = (candle["open"] - self.position["entry_price"]) / candle["open"] + return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) + def get_state(self, index): + state_features = [] + base_ts = self.candle_window[index]["timestamp"] + for tf in self.timeframes: + if tf == self.base_tf: + candle = self.candle_window[index] + features = get_features_for_tf([candle], 0) + else: + aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) + features = get_features_for_tf(self.candles_dict[tf], aligned_idx) + state_features.append(features) + order_features = self.get_order_features(index) + state_features.append(order_features) + for _ in range(NUM_INDICATORS): + state_features.append([0.0]*FEATURES_PER_CHANNEL) + return np.array(state_features, dtype=np.float32) + def step(self, action): + base = self.candle_window + if 
self.current_index >= len(base)-1: + current_state = self.get_state(self.current_index) + return current_state, 0.0, None, True, 0.0, 0.0 + current_state = self.get_state(self.current_index) + next_index = self.current_index + 1 + next_state = self.get_state(next_index) + next_candle = base[next_index] + reward = 0.0 + if self.position is None: + if action == 2: # BUY + self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} + else: + if action == 0: # SELL + exit_price = next_candle["close"] + reward = exit_price - self.position["entry_price"] + trade = { + "entry_index": self.position["entry_index"], + "entry_price": self.position["entry_price"], + "exit_index": next_index, + "exit_price": exit_price, + "pnl": reward + } + self.trade_history.append(trade) + self.position = None + self.current_index = next_index + done = (self.current_index >= len(base)-1) + actual_high = next_candle["high"] + actual_low = next_candle["low"] + return current_state, reward, next_state, done, actual_high, actual_low + +# ------------------------------- +# Enhanced Training Loop +# ------------------------------- +def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): + lambda_trade = args.lambda_trade + training_cache = load_training_cache(TRAINING_CACHE_FILE) + total_pnl = training_cache.get("total_pnl", 0.0) + for epoch in range(start_epoch, args.epochs): + env.reset() + loss_accum = 0.0 + steps = len(env) - 1 + for i in range(steps): + state = env.get_state(i) + current_open = env.candle_window[i]["open"] + actual_high = env.candle_window[i+1]["high"] + actual_low = env.candle_window[i+1]["low"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ + torch.abs(pred_low - torch.tensor(actual_low, device=device)) + profit_buy = pred_high - current_open + profit_sell = current_open - pred_low + L_trade = - torch.max(profit_buy, profit_sell) + current_open_tensor = torch.tensor(current_open, device=device) + signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low) + penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0) + loss = L_pred + lambda_trade * L_trade + penalty_term + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + loss_accum += loss.item() + scheduler.step() + epoch_loss = loss_accum / steps + if len(env.trade_history) == 0: + epoch_loss *= 3 + epoch_pnl = sum(trade["pnl"] for trade in env.trade_history) + total_pnl += epoch_pnl + print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}") + save_checkpoint(model, optimizer, epoch, loss_accum) + simulate_trades(model, env, device, args) + update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl) + training_cache["total_pnl"] = total_pnl + save_training_cache(TRAINING_CACHE_FILE, training_cache) + +# ------------------------------- +# Live Plotting (for Live Mode) +# ------------------------------- +def live_preview_loop(candles, env): + plt.ion() + fig, ax = plt.subplots(figsize=(12, 6)) + while True: + update_live_chart(ax, candles, env.trade_history) + plt.draw() + plt.pause(1) + +# ------------------------------- +# Live HTML Chart Update (with Volume and 
Loss) +# ------------------------------- +def update_live_html(candles, trade_history, epoch, loss, total_pnl): + from io import BytesIO + import base64 + fig, ax = plt.subplots(figsize=(12, 6)) + update_live_chart(ax, candles, trade_history) + epoch_pnl = sum(trade["pnl"] for trade in trade_history) + ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}") + buf = BytesIO() + fig.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') + html_content = f""" + + + + + + Live Trading Chart - Epoch {epoch} + + + +
+            <h3>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h3>
+            <img src="data:image/png;base64,{image_base64}" alt="Live Chart">
+ + +""" + with open("live_chart.html", "w") as f: + f.write(html_content) + print("Updated live_chart.html.") + +# ------------------------------- +# Chart Drawing Helpers (with Volume and Date+Time) +# ------------------------------- +def update_live_chart(ax, candles, trade_history): + ax.clear() + times = [convert_timestamp(candle["timestamp"]) for candle in candles] + close_prices = [candle["close"] for candle in candles] + ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) + ax.set_xlabel("Time") + ax.set_ylabel("Price") + ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) + ax2 = ax.twinx() + volumes = [candle["volume"] for candle in candles] + if len(times) > 1: + times_num = mdates.date2num(times) + bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 + else: + bar_width = 0.01 + ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") + ax2.set_ylabel("Volume") + for trade in trade_history: + entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) + exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) + in_price = trade["entry_price"] + out_price = trade["exit_price"] + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") + ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax.legend(lines + lines2, labels + labels2) + ax.grid(True) + fig = ax.get_figure() + fig.autofmt_xdate() + +# ------------------------------- +# Global Constants for Features +# ------------------------------- +NUM_INDICATORS = 20 +FEATURES_PER_CHANNEL = 7 +ORDER_CHANNELS = 1 + +# ------------------------------- +# Backtest Environment with Sliding Window and Order Info (Already Defined Above) +# [See BacktestEnvironment class above] +# ------------------------------- + +# ------------------------------- +# General Simulation of Trades Function +# ------------------------------- +def simulate_trades(model, env, device, args): + if args.main_tf == "1s": + simulate_trades_1s(env) + return + env.reset() + while True: + i = env.current_index + if i >= len(env.candle_window) - 1: + break + state = env.get_state(i) + current_open = env.candle_window[i]["open"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + pred_high = pred_high.item() + pred_low = pred_low.item() + if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: + if (pred_high - current_open) >= (current_open - pred_low): + action = 2 + else: + action = 0 + _, _, _, done, _, _ = env.step(action) + else: + manual_trade(env) + if env.current_index >= len(env.candle_window) - 1: + break + +# ------------------------------- +# Argument Parsing +# ------------------------------- def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') @@ -638,35 +807,55 @@ def parse_args(): parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--main_tf', type=str, default='1m', help="Desired main timeframe to focus on (e.g., '1s' or '1m').") + parser.add_argument('--fetch', action='store_true', 
help="Fetch fresh data from exchange on start.") + parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.") return parser.parse_args() def random_action(): return random.randint(0, 2) -# --- Main Function --- +# ------------------------------- +# Main Function +# ------------------------------- async def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) - # Load cached candles. - candles_dict = load_candles_cache(CACHE_FILE) - # If the desired main timeframe is 1s, attempt to fetch real 1s data. - if args.main_tf == "1s": - real_1s_data = fetch_real_1s_data() - if real_1s_data: - candles_dict["1s"] = real_1s_data - if not candles_dict: - print("No historical candle data available for backtesting.") - return - # Define desired timeframes list; if available, include "1s". + + # If --fetch flag is provided, top-up cached OHLCV data with fresh data from exchange. + if args.fetch: + import ccxt.async_support as ccxt + exchange = ccxt.binance({'enableRateLimit': True}) + now_ms = int(time.time()*1000) + # Determine default "since" time based on cache. + cached = load_candles_cache(CACHE_FILE) + if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0: + last_ts = cached[args.main_tf][-1]['timestamp'] + since = last_ts + 1 + else: + # Default: fetch candles from the last 2 days. + since = now_ms - 2*24*60*60*1000 + # Top-up data for the main timeframe. + print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...") + fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms) + # Update cache (for simplicity, we store only the main timeframe here). + candles_dict = {args.main_tf: fresh_candles} + save_candles_cache(CACHE_FILE, candles_dict) + await exchange.close() + else: + candles_dict = load_candles_cache(CACHE_FILE) + if not candles_dict: + print("No cached data available. Run with --fetch to load fresh data from the exchange.") + return + + # Define desired timeframes list. default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"] timeframes = [tf for tf in default_timeframes if tf in candles_dict] if args.main_tf not in timeframes: print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}") return - base_tf = args.main_tf # Set the main timeframe as the base for the environment. + base_tf = args.main_tf hidden_dim = 128 - # Total channels: number of timeframes + 1 order channel + NUM_INDICATORS. total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS model = TradingModel(total_channels, len(timeframes)).to(device) @@ -700,7 +889,6 @@ async def main(): else: print("No valid optimizer state found; using fresh optimizer state.") train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) - elif args.mode == 'live': load_best_checkpoint(model) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)