Better training algorithm to force more trades

This commit is contained in:
Dobromir Popov 2025-02-04 22:12:55 +02:00
parent 615579d456
commit d2686b31b7

View File

@ -216,8 +216,9 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Live HTML Chart Update --- # --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch): def update_live_html(candles, trade_history, epoch):
""" """
Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL. Generate a chart image that uses actual timestamps on the x-axis
The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes. and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
is embedded in an HTML page that auto-refreshes every 10 seconds.
""" """
from io import BytesIO from io import BytesIO
import base64 import base64
@ -272,7 +273,7 @@ def update_live_html(candles, trade_history, epoch):
# --- Chart Drawing Helpers --- # --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history): def update_live_chart(ax, candles, trade_history):
""" """
Plot the price chart using actual timestamps on the x-axis. Plot the price chart with actual timestamps on the x-axis.
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit. Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
""" """
ax.clear() ax.clear()
@ -282,8 +283,6 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels. # Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Calculate epoch PnL.
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
# Plot each trade. # Plot each trade.
buy_label_added = False buy_label_added = False
sell_label_added = False sell_label_added = False
@ -406,7 +405,7 @@ class BacktestEnvironment:
""" """
Execute one step in the environment: Execute one step in the environment:
- action: 0 => SELL, 1 => HOLD, 2 => BUY. - action: 0 => SELL, 1 => HOLD, 2 => BUY.
- Trades recorded when a BUY is followed by a SELL. - Trades are recorded when a BUY is followed by a SELL.
""" """
base = self.candle_window base = self.candle_window
if self.current_index >= len(base) - 1: if self.current_index >= len(base) - 1:
@ -460,11 +459,15 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
# Prediction loss (L1 error). # Prediction loss (L1 error).
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
torch.abs(pred_low - torch.tensor(actual_low, device=device)) torch.abs(pred_low - torch.tensor(actual_low, device=device))
# Surrogate profit loss: # Surrogate profit loss.
profit_buy = pred_high - current_open # potential long gain profit_buy = pred_high - current_open # potential long gain
profit_sell = current_open - pred_low # potential short gain profit_sell = current_open - pred_low # potential short gain
L_trade = - torch.max(profit_buy, profit_sell) L_trade = - torch.max(profit_buy, profit_sell)
loss = L_pred + lambda_trade * L_trade # Additional penalty if no strong signal is produced.
current_open_tensor = torch.tensor(current_open, device=device)
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
loss = L_pred + lambda_trade * L_trade + penalty_term
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@ -494,6 +497,7 @@ def parse_args():
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args() return parser.parse_args()