better training algo to force more trades
This commit is contained in:
parent
615579d456
commit
d2686b31b7
@ -216,8 +216,9 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
|||||||
# --- Live HTML Chart Update ---
|
# --- Live HTML Chart Update ---
|
||||||
def update_live_html(candles, trade_history, epoch):
|
def update_live_html(candles, trade_history, epoch):
|
||||||
"""
|
"""
|
||||||
Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL.
|
Generate a chart image that uses actual timestamps on the x-axis
|
||||||
The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes.
|
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
|
||||||
|
is embedded in an HTML page that auto-refreshes every 10 seconds.
|
||||||
"""
|
"""
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import base64
|
import base64
|
||||||
@ -272,7 +273,7 @@ def update_live_html(candles, trade_history, epoch):
|
|||||||
# --- Chart Drawing Helpers ---
|
# --- Chart Drawing Helpers ---
|
||||||
def update_live_chart(ax, candles, trade_history):
|
def update_live_chart(ax, candles, trade_history):
|
||||||
"""
|
"""
|
||||||
Plot the price chart using actual timestamps on the x-axis.
|
Plot the price chart with actual timestamps on the x-axis.
|
||||||
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
|
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
|
||||||
"""
|
"""
|
||||||
ax.clear()
|
ax.clear()
|
||||||
@ -282,8 +283,6 @@ def update_live_chart(ax, candles, trade_history):
|
|||||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||||
# Format x-axis date labels.
|
# Format x-axis date labels.
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
||||||
# Calculate epoch PnL.
|
|
||||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
|
||||||
# Plot each trade.
|
# Plot each trade.
|
||||||
buy_label_added = False
|
buy_label_added = False
|
||||||
sell_label_added = False
|
sell_label_added = False
|
||||||
@ -406,7 +405,7 @@ class BacktestEnvironment:
|
|||||||
"""
|
"""
|
||||||
Execute one step in the environment:
|
Execute one step in the environment:
|
||||||
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
|
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
|
||||||
- Trades recorded when a BUY is followed by a SELL.
|
- Trades are recorded when a BUY is followed by a SELL.
|
||||||
"""
|
"""
|
||||||
base = self.candle_window
|
base = self.candle_window
|
||||||
if self.current_index >= len(base) - 1:
|
if self.current_index >= len(base) - 1:
|
||||||
@ -460,11 +459,15 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
|||||||
# Prediction loss (L1 error).
|
# Prediction loss (L1 error).
|
||||||
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
|
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
|
||||||
torch.abs(pred_low - torch.tensor(actual_low, device=device))
|
torch.abs(pred_low - torch.tensor(actual_low, device=device))
|
||||||
# Surrogate profit loss:
|
# Surrogate profit loss.
|
||||||
profit_buy = pred_high - current_open # potential long gain
|
profit_buy = pred_high - current_open # potential long gain
|
||||||
profit_sell = current_open - pred_low # potential short gain
|
profit_sell = current_open - pred_low # potential short gain
|
||||||
L_trade = - torch.max(profit_buy, profit_sell)
|
L_trade = - torch.max(profit_buy, profit_sell)
|
||||||
loss = L_pred + lambda_trade * L_trade
|
# Additional penalty if no strong signal is produced.
|
||||||
|
current_open_tensor = torch.tensor(current_open, device=device)
|
||||||
|
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
|
||||||
|
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
|
||||||
|
loss = L_pred + lambda_trade * L_trade + penalty_term
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
@ -494,6 +497,7 @@ def parse_args():
|
|||||||
parser.add_argument('--lr', type=float, default=3e-4)
|
parser.add_argument('--lr', type=float, default=3e-4)
|
||||||
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
|
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
|
||||||
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
||||||
|
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).")
|
||||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user