better traing algo to force more trades

This commit is contained in:
Dobromir Popov 2025-02-04 22:12:55 +02:00
parent 615579d456
commit d2686b31b7

View File

@ -216,8 +216,9 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
"""
Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL.
The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes.
Generate a chart image that uses actual timestamps on the x-axis
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
is embedded in an HTML page that auto-refreshes every 10 seconds.
"""
from io import BytesIO
import base64
@ -272,7 +273,7 @@ def update_live_html(candles, trade_history, epoch):
# --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history):
"""
Plot the price chart using actual timestamps on the x-axis.
Plot the price chart with actual timestamps on the x-axis.
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
"""
ax.clear()
@ -282,8 +283,6 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Calculate epoch PnL.
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
# Plot each trade.
buy_label_added = False
sell_label_added = False
@ -406,7 +405,7 @@ class BacktestEnvironment:
"""
Execute one step in the environment:
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
- Trades recorded when a BUY is followed by a SELL.
- Trades are recorded when a BUY is followed by a SELL.
"""
base = self.candle_window
if self.current_index >= len(base) - 1:
@ -460,11 +459,15 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
# Prediction loss (L1 error).
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
torch.abs(pred_low - torch.tensor(actual_low, device=device))
# Surrogate profit loss:
# Surrogate profit loss.
profit_buy = pred_high - current_open # potential long gain
profit_sell = current_open - pred_low # potential short gain
L_trade = - torch.max(profit_buy, profit_sell)
loss = L_pred + lambda_trade * L_trade
# Additional penalty if no strong signal is produced.
current_open_tensor = torch.tensor(current_open, device=device)
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
loss = L_pred + lambda_trade * L_trade + penalty_term
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@ -494,6 +497,7 @@ def parse_args():
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args()