more strong model response

This commit is contained in:
Dobromir Popov 2025-02-04 22:27:33 +02:00
parent 907468239a
commit 75c4d6602a
2 changed files with 14 additions and 14 deletions

View File

@ -327,6 +327,8 @@ def update_live_chart(ax, candles, trade_history):
def simulate_trades(model, env, device, args): def simulate_trades(model, env, device, args):
""" """
Run a simulation on the current sliding window using the model's outputs and a decision rule. Run a simulation on the current sliding window using the model's outputs and a decision rule.
Here we force the simulation to always take an action by comparing the predicted potentials,
ensuring that the model is forced to trade (either BUY or SELL) rather than HOLD.
This simulation updates env.trade_history and is used for visualization only. This simulation updates env.trade_history and is used for visualization only.
""" """
env.reset() # resets the window and index env.reset() # resets the window and index
@ -339,12 +341,11 @@ def simulate_trades(model, env, device, args):
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: # Force a trade: choose BUY if predicted up-move is higher (or equal), else SELL.
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY action = 2 # BUY
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
action = 0 # SELL
else: else:
action = 1 # HOLD action = 0 # SELL
_, _, _, done, _, _ = env.step(action) _, _, _, done, _, _ = env.step(action)
if done: if done:
break break
@ -481,9 +482,9 @@ def parse_args():
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade (used in loss).")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken.") parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args() return parser.parse_args()
@ -545,7 +546,7 @@ async def main():
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100) env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start() preview_thread.start()
print("Starting live trading loop. (Using model-based decision rule.)") print("Starting live trading loop. (Forcing trade actions based on highest potential.)")
while True: while True:
state = env.get_state(env.current_index) state = env.get_state(env.current_index)
current_open = env.candle_window[env.current_index]["open"] current_open = env.candle_window[env.current_index]["open"]
@ -554,12 +555,11 @@ async def main():
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: # Force a trade (choose BUY if upward potential >= downward, else SELL)
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 action = 2
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
action = 0
else: else:
action = 1 action = 0
_, _, _, done, _, _ = env.step(action) _, _, _, done, _, _ = env.step(action)
if done: if done:
print("Reached end of simulation window; resetting environment.") print("Reached end of simulation window; resetting environment.")

File diff suppressed because one or more lines are too long