more strong model response
This commit is contained in:
parent
907468239a
commit
75c4d6602a
@ -327,6 +327,8 @@ def update_live_chart(ax, candles, trade_history):
|
|||||||
def simulate_trades(model, env, device, args):
|
def simulate_trades(model, env, device, args):
|
||||||
"""
|
"""
|
||||||
Run a simulation on the current sliding window using the model's outputs and a decision rule.
|
Run a simulation on the current sliding window using the model's outputs and a decision rule.
|
||||||
|
Here we force the simulation to always take an action by comparing the predicted potentials,
|
||||||
|
ensuring that the model is forced to trade (either BUY or SELL) rather than HOLD.
|
||||||
This simulation updates env.trade_history and is used for visualization only.
|
This simulation updates env.trade_history and is used for visualization only.
|
||||||
"""
|
"""
|
||||||
env.reset() # resets the window and index
|
env.reset() # resets the window and index
|
||||||
@ -339,12 +341,11 @@ def simulate_trades(model, env, device, args):
|
|||||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||||
pred_high = pred_high.item()
|
pred_high = pred_high.item()
|
||||||
pred_low = pred_low.item()
|
pred_low = pred_low.item()
|
||||||
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
|
# Force a trade: choose BUY if predicted up-move is higher (or equal), else SELL.
|
||||||
|
if (pred_high - current_open) >= (current_open - pred_low):
|
||||||
action = 2 # BUY
|
action = 2 # BUY
|
||||||
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
|
|
||||||
action = 0 # SELL
|
|
||||||
else:
|
else:
|
||||||
action = 1 # HOLD
|
action = 0 # SELL
|
||||||
_, _, _, done, _, _ = env.step(action)
|
_, _, _, done, _, _ = env.step(action)
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
@ -481,9 +482,9 @@ def parse_args():
|
|||||||
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
|
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
|
||||||
parser.add_argument('--epochs', type=int, default=100)
|
parser.add_argument('--epochs', type=int, default=100)
|
||||||
parser.add_argument('--lr', type=float, default=3e-4)
|
parser.add_argument('--lr', type=float, default=3e-4)
|
||||||
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
|
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade (used in loss).")
|
||||||
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
||||||
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken.")
|
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).")
|
||||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -545,7 +546,7 @@ async def main():
|
|||||||
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
|
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
|
||||||
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
||||||
preview_thread.start()
|
preview_thread.start()
|
||||||
print("Starting live trading loop. (Using model-based decision rule.)")
|
print("Starting live trading loop. (Forcing trade actions based on highest potential.)")
|
||||||
while True:
|
while True:
|
||||||
state = env.get_state(env.current_index)
|
state = env.get_state(env.current_index)
|
||||||
current_open = env.candle_window[env.current_index]["open"]
|
current_open = env.candle_window[env.current_index]["open"]
|
||||||
@ -554,12 +555,11 @@ async def main():
|
|||||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||||
pred_high = pred_high.item()
|
pred_high = pred_high.item()
|
||||||
pred_low = pred_low.item()
|
pred_low = pred_low.item()
|
||||||
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
|
# Force a trade (choose BUY if upward potential >= downward, else SELL)
|
||||||
|
if (pred_high - current_open) >= (current_open - pred_low):
|
||||||
action = 2
|
action = 2
|
||||||
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
|
|
||||||
action = 0
|
|
||||||
else:
|
else:
|
||||||
action = 1
|
action = 0
|
||||||
_, _, _, done, _, _ = env.step(action)
|
_, _, _, done, _, _ = env.step(action)
|
||||||
if done:
|
if done:
|
||||||
print("Reached end of simulation window; resetting environment.")
|
print("Reached end of simulation window; resetting environment.")
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user