From 615579d456c38ce4b4ec6e0b9ca4efaa14e79426 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 4 Feb 2025 22:10:24 +0200 Subject: [PATCH] better train algo --- crypto/brian/index-deep-new.py | 147 +++++++++++++++++++++------------ 1 file changed, 92 insertions(+), 55 deletions(-) diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index e6f8b70..5b68e5b 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -18,6 +18,7 @@ from datetime import datetime import matplotlib.pyplot as plt import math from torch.nn import TransformerEncoder, TransformerEncoderLayer +import matplotlib.dates as mdates from dotenv import load_dotenv load_dotenv() @@ -29,9 +30,12 @@ os.makedirs(BEST_DIR, exist_ok=True) CACHE_FILE = "candles_cache.json" # --- Constants --- -NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] -NUM_INDICATORS = 20 # e.g., 20 technical indicators -FEATURES_PER_CHANNEL = 7 # e.g., [open, high, low, close, volume, sma_close, sma_volume] +NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] +NUM_INDICATORS = 20 # e.g., 20 technical indicators +# Each channel input will have 7 features. +FEATURES_PER_CHANNEL = 7 +# We add one extra channel for order information. +ORDER_CHANNELS = 1 # --- Positional Encoding Module --- class PositionalEncoding(nn.Module): @@ -52,7 +56,7 @@ class PositionalEncoding(nn.Module): class TradingModel(nn.Module): def __init__(self, num_channels, num_timeframes, hidden_dim=128): super().__init__() - # One branch per channel + # Create one branch per channel. self.channel_branches = nn.ModuleList([ nn.Sequential( nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), @@ -61,6 +65,7 @@ class TradingModel(nn.Module): nn.Dropout(0.1) ) for _ in range(num_channels) ]) + # Embedding for channels 0..num_channels-1. self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim) encoder_layers = TransformerEncoderLayer( @@ -86,8 +91,8 @@ class TradingModel(nn.Module): for i in range(num_channels): channel_out = self.channel_branches[i](x[:, i, :]) channel_outs.append(channel_out) - stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] - stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden] + stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] + stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden] tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) stacked = stacked + tf_embeds src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) @@ -151,7 +156,7 @@ def maintain_checkpoint_directory(directory, max_files=10): if len(files) > max_files: full_paths = [os.path.join(directory, f) for f in files] full_paths.sort(key=lambda x: os.path.getmtime(x)) - for f in full_paths[:len(files)-max_files]: + for f in full_paths[:len(files) - max_files]: os.remove(f) def get_best_models(directory): @@ -170,10 +175,10 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt" last_path = os.path.join(last_dir, last_filename) torch.save({ - "epoch": epoch, - "loss": loss, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict() + "epoch": epoch, + "loss": loss, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict() }, last_path) maintain_checkpoint_directory(last_dir, max_files=10) best_models = get_best_models(best_dir) @@ -211,15 +216,17 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): # --- Live HTML Chart Update --- def update_live_html(candles, trade_history, epoch): """ - Generate a chart image with buy/sell markers and dotted lines between entry and exit, - then embed it in an auto-refreshing HTML page. + Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL. + The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes. """ from io import BytesIO import base64 fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) - ax.set_title(f"Live Trading Chart - Epoch {epoch}") + # Compute cumulative epoch PnL. + epoch_pnl = sum(trade["pnl"] for trade in trade_history) + ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}") buf = BytesIO() fig.savefig(buf, format='png') plt.close(fig) @@ -252,7 +259,7 @@ def update_live_html(candles, trade_history, epoch):
-

Live Trading Chart - Epoch {epoch}

+

Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}

Live Chart
@@ -265,42 +272,51 @@ def update_live_html(candles, trade_history, epoch): # --- Chart Drawing Helpers --- def update_live_chart(ax, candles, trade_history): """ - Draw the price chart with close prices and mark BUY (green) and SELL (red) actions. + Plot the price chart using actual timestamps on the x-axis. + Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit. """ ax.clear() + # Convert timestamps to datetime objects. + times = [datetime.fromtimestamp(candle["timestamp"]) for candle in candles] close_prices = [candle["close"] for candle in candles] - x = list(range(len(close_prices))) - ax.plot(x, close_prices, label="Close Price", color="black", linewidth=1) + ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) + # Format x-axis date labels. + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + # Calculate epoch PnL. + epoch_pnl = sum(trade["pnl"] for trade in trade_history) + # Plot each trade. buy_label_added = False sell_label_added = False for trade in trade_history: - in_idx = trade["entry_index"] - out_idx = trade["exit_index"] + entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"]) + exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"]) in_price = trade["entry_price"] out_price = trade["exit_price"] if not buy_label_added: - ax.plot(in_idx, in_price, marker="^", color="green", markersize=10, label="BUY") + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") buy_label_added = True else: - ax.plot(in_idx, in_price, marker="^", color="green", markersize=10) + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10) if not sell_label_added: - ax.plot(out_idx, out_price, marker="v", color="red", markersize=10, label="SELL") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") sell_label_added = True else: - ax.plot(out_idx, out_price, marker="v", color="red", markersize=10) - ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue") - ax.set_xlabel("Candle Index") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10) + ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") + ax.set_xlabel("Time") ax.set_ylabel("Price") ax.legend() ax.grid(True) + fig = ax.get_figure() + fig.autofmt_xdate() # --- Simulation of Trades for Visualization --- def simulate_trades(model, env, device, args): """ - Run a complete simulation on the current sliding window using a decision rule based on model outputs. - This simulation (which updates env.trade_history) is used only for visualization. + Run a simulation on the current sliding window using the model's outputs and a decision rule. + This simulation updates env.trade_history and is used for visualization only. """ - env.reset() # resets the sliding window and index + env.reset() # resets the window and index while True: i = env.current_index state = env.get_state(i) @@ -310,7 +326,7 @@ def simulate_trades(model, env, device, args): pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high = pred_high.item() pred_low = pred_low.item() - # Decision rule: if upward move larger than downward and above threshold, BUY; if downward is larger, SELL; else HOLD. + # Simple decision rule based on predicted move. if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: action = 2 # BUY elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: @@ -321,23 +337,22 @@ def simulate_trades(model, env, device, args): if done: break -# --- Backtest Environment with Sliding Window --- +# --- Backtest Environment with Sliding Window and Order Info --- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes, window_size=None): - self.candles_dict = candles_dict # full candles dict for all timeframes + self.candles_dict = candles_dict # full candles dict across timeframes self.base_tf = base_tf self.timeframes = timeframes self.full_candles = candles_dict[base_tf] if window_size is None: window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles) self.window_size = window_size - self.hint_penalty = 0.001 # not used in the revised loss below self.reset() def reset(self): - # Pick a random sliding window from the full dataset. + # Randomly select a sliding window from the full dataset. self.start_index = random.randint(0, len(self.full_candles) - self.window_size) - self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size] + self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] self.current_index = 0 self.trade_history = [] self.position = None @@ -346,7 +361,29 @@ class BacktestEnvironment: def __len__(self): return self.window_size + def get_order_features(self, index): + """ + Returns a list of 7 features for the order channel. + If an order is open, the first element is 1.0 and the second is the normalized difference: + (current open - entry_price) / current open. + Otherwise, returns zeros. + """ + candle = self.candle_window[index] + if self.position is None: + return [0.0] * FEATURES_PER_CHANNEL + else: + flag = 1.0 + diff = (candle["open"] - self.position["entry_price"]) / candle["open"] + return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) + def get_state(self, index): + """ + Build state features from: + - For each timeframe: features from the aligned candle. + - One extra channel: current order information. + - NUM_INDICATORS channels of zeros. + Each channel is a vector of length FEATURES_PER_CHANNEL. + """ state_features = [] base_ts = self.candle_window[index]["timestamp"] for tf in self.timeframes: @@ -357,15 +394,19 @@ class BacktestEnvironment: aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) features = get_features_for_tf(self.candles_dict[tf], aligned_idx) state_features.append(features) + # Append order channel. + order_features = self.get_order_features(index) + state_features.append(order_features) + # Append technical indicator channels. for _ in range(NUM_INDICATORS): - state_features.append([0.0]*FEATURES_PER_CHANNEL) + state_features.append([0.0] * FEATURES_PER_CHANNEL) return np.array(state_features, dtype=np.float32) def step(self, action): """ - Discrete simulation step. - - Action: 0 (SELL), 1 (HOLD), 2 (BUY). - - Trades are recorded when a BUY is followed by a SELL. + Execute one step in the environment: + - action: 0 => SELL, 1 => HOLD, 2 => BUY. + - Trades recorded when a BUY is followed by a SELL. """ base = self.candle_window if self.current_index >= len(base) - 1: @@ -378,7 +419,6 @@ class BacktestEnvironment: next_candle = base[next_index] reward = 0.0 - # Simple trading logic (only one position allowed at a time) if self.position is None: if action == 2: # BUY signal: enter at next open. self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} @@ -404,29 +444,25 @@ class BacktestEnvironment: # --- Enhanced Training Loop --- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): - # Weighting factor for trade surrogate loss. - lambda_trade = 1.0 + lambda_trade = args.lambda_trade # Weight for the surrogate profit loss. for epoch in range(start_epoch, args.epochs): - # Reset sliding window for each epoch. - env.reset() + env.reset() # Resets the sliding window. loss_accum = 0.0 - steps = len(env) - 1 # we use pairs of consecutive candles + steps = len(env) - 1 # We assume steps over consecutive candle pairs. for i in range(steps): state = env.get_state(i) current_open = env.candle_window[i]["open"] - # Next candle's actual values serve as targets. actual_high = env.candle_window[i+1]["high"] actual_low = env.candle_window[i+1]["low"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) - # Compute prediction loss (L1) + # Prediction loss (L1 error). L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ torch.abs(pred_low - torch.tensor(actual_low, device=device)) - # Compute surrogate profit (differentiable estimate) + # Surrogate profit loss: profit_buy = pred_high - current_open # potential long gain profit_sell = current_open - pred_low # potential short gain - # Here we reward a higher potential move by subtracting it. L_trade = - torch.max(profit_buy, profit_sell) loss = L_pred + lambda_trade * L_trade optimizer.zero_grad() @@ -438,7 +474,6 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s epoch_loss = loss_accum / steps print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}") save_checkpoint(model, optimizer, epoch, loss_accum) - # Update the trade simulation (for visualization) using the current model on the same window. simulate_trades(model, env, device, args) update_live_html(env.candle_window, env.trade_history, epoch+1) @@ -458,10 +493,13 @@ def parse_args(): parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") - parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.") + parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") return parser.parse_args() +def random_action(): + return random.randint(0, 2) + # --- Main Function --- async def main(): args = parse_args() @@ -469,7 +507,8 @@ async def main(): print("Using device:", device) timeframes = ["1m", "5m", "15m", "1h", "1d"] hidden_dim = 128 - total_channels = NUM_TIMEFRAMES + NUM_INDICATORS + # Total channels: NUM_TIMEFRAMES + 1 (order info) + NUM_INDICATORS. + total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device) if args.mode == 'train': @@ -478,7 +517,6 @@ async def main(): print("No historical candle data available for backtesting.") return base_tf = "1m" - # Use a sliding window of up to 100 candles (if available) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) start_epoch = 0 checkpoint = None @@ -513,7 +551,6 @@ async def main(): preview_thread.start() print("Starting live trading loop. (Using model-based decision rule.)") while True: - # In live mode, we use the simulation decision rule. state = env.get_state(env.current_index) current_open = env.candle_window[env.current_index]["open"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) @@ -535,7 +572,7 @@ async def main(): elif args.mode == 'inference': load_best_checkpoint(model) print("Running inference...") - # Inference logic can use a similar decision rule as in live mode. + # Your inference logic goes here. else: print("Invalid mode specified.")