better traing algo to force more trades
This commit is contained in:
parent
615579d456
commit
d2686b31b7
@ -216,8 +216,9 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
||||
# --- Live HTML Chart Update ---
|
||||
def update_live_html(candles, trade_history, epoch):
|
||||
"""
|
||||
Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL.
|
||||
The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes.
|
||||
Generate a chart image that uses actual timestamps on the x-axis
|
||||
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
|
||||
is embedded in an HTML page that auto-refreshes every 10 seconds.
|
||||
"""
|
||||
from io import BytesIO
|
||||
import base64
|
||||
@ -272,7 +273,7 @@ def update_live_html(candles, trade_history, epoch):
|
||||
# --- Chart Drawing Helpers ---
|
||||
def update_live_chart(ax, candles, trade_history):
|
||||
"""
|
||||
Plot the price chart using actual timestamps on the x-axis.
|
||||
Plot the price chart with actual timestamps on the x-axis.
|
||||
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
|
||||
"""
|
||||
ax.clear()
|
||||
@ -282,8 +283,6 @@ def update_live_chart(ax, candles, trade_history):
|
||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||
# Format x-axis date labels.
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
||||
# Calculate epoch PnL.
|
||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
||||
# Plot each trade.
|
||||
buy_label_added = False
|
||||
sell_label_added = False
|
||||
@ -406,7 +405,7 @@ class BacktestEnvironment:
|
||||
"""
|
||||
Execute one step in the environment:
|
||||
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
|
||||
- Trades recorded when a BUY is followed by a SELL.
|
||||
- Trades are recorded when a BUY is followed by a SELL.
|
||||
"""
|
||||
base = self.candle_window
|
||||
if self.current_index >= len(base) - 1:
|
||||
@ -460,11 +459,15 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
||||
# Prediction loss (L1 error).
|
||||
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
|
||||
torch.abs(pred_low - torch.tensor(actual_low, device=device))
|
||||
# Surrogate profit loss:
|
||||
# Surrogate profit loss.
|
||||
profit_buy = pred_high - current_open # potential long gain
|
||||
profit_sell = current_open - pred_low # potential short gain
|
||||
L_trade = - torch.max(profit_buy, profit_sell)
|
||||
loss = L_pred + lambda_trade * L_trade
|
||||
# Additional penalty if no strong signal is produced.
|
||||
current_open_tensor = torch.tensor(current_open, device=device)
|
||||
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
|
||||
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
|
||||
loss = L_pred + lambda_trade * L_trade + penalty_term
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
@ -494,6 +497,7 @@ def parse_args():
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
|
||||
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
||||
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).")
|
||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user