From d2686b31b798226b37e27900eddab69b843710d0 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 4 Feb 2025 22:12:55 +0200 Subject: [PATCH] better training algo to force more trades --- crypto/brian/index-deep-new.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 5b68e5b..94bf4c6 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -216,8 +216,9 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): # --- Live HTML Chart Update --- def update_live_html(candles, trade_history, epoch): """ - Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL. - The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes. + Generate a chart image that uses actual timestamps on the x-axis + and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines) + is embedded in an HTML page that auto-refreshes every 10 seconds. """ from io import BytesIO import base64 @@ -272,7 +273,7 @@ def update_live_html(candles, trade_history, epoch): # --- Chart Drawing Helpers --- def update_live_chart(ax, candles, trade_history): """ - Plot the price chart using actual timestamps on the x-axis. + Plot the price chart with actual timestamps on the x-axis. Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit. """ ax.clear() @@ -282,8 +283,6 @@ def update_live_chart(ax, candles, trade_history): ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) # Format x-axis date labels. ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) - # Calculate epoch PnL. - epoch_pnl = sum(trade["pnl"] for trade in trade_history) # Plot each trade. 
buy_label_added = False sell_label_added = False @@ -406,7 +405,7 @@ class BacktestEnvironment: """ Execute one step in the environment: - action: 0 => SELL, 1 => HOLD, 2 => BUY. - - Trades recorded when a BUY is followed by a SELL. + - Trades are recorded when a BUY is followed by a SELL. """ base = self.candle_window if self.current_index >= len(base) - 1: @@ -460,11 +459,15 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s # Prediction loss (L1 error). L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ torch.abs(pred_low - torch.tensor(actual_low, device=device)) - # Surrogate profit loss: + # Surrogate profit loss. profit_buy = pred_high - current_open # potential long gain profit_sell = current_open - pred_low # potential short gain L_trade = - torch.max(profit_buy, profit_sell) - loss = L_pred + lambda_trade * L_trade + # Additional penalty if no strong signal is produced. + current_open_tensor = torch.tensor(current_open, device=device) + signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low) + penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0) + loss = L_pred + lambda_trade * L_trade + penalty_term optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) @@ -494,6 +497,7 @@ def parse_args(): parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") + parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") return parser.parse_args()