o3 suggestions

This commit is contained in:
Dobromir Popov 2025-02-05 10:53:21 +02:00
parent a58f9810bd
commit a04a2d7677

View File

@ -253,7 +253,8 @@ def manual_trade(env):
When no sufficient action is taken by the model, use a fallback: When no sufficient action is taken by the model, use a fallback:
Scan the remaining window for the global maximum and minimum. Scan the remaining window for the global maximum and minimum.
If the maximum occurs before the minimum, simulate a short trade; If the maximum occurs before the minimum, simulate a short trade;
otherwise simulate a long trade. Closes the trade at the candle where the chosen extreme occurs. otherwise simulate a long trade.
Use the candle "close" prices to compute trade reward.
""" """
current_index = env.current_index current_index = env.current_index
if current_index >= len(env.candle_window) - 1: if current_index >= len(env.candle_window) - 1:
@ -273,8 +274,8 @@ def manual_trade(env):
min_val = low_j min_val = low_j
i_min = j i_min = j
if i_max < i_min: if i_max < i_min:
entry_price = env.candle_window[current_index]["open"] entry_price = env.candle_window[current_index]["close"]
exit_price = env.candle_window[i_min]["open"] exit_price = env.candle_window[i_min]["close"]
reward = entry_price - exit_price reward = entry_price - exit_price
trade = { trade = {
"entry_index": current_index, "entry_index": current_index,
@ -284,8 +285,8 @@ def manual_trade(env):
"pnl": reward "pnl": reward
} }
else: else:
entry_price = env.candle_window[current_index]["open"] entry_price = env.candle_window[current_index]["close"]
exit_price = env.candle_window[i_max]["open"] exit_price = env.candle_window[i_max]["close"]
reward = exit_price - entry_price reward = exit_price - entry_price
trade = { trade = {
"entry_index": current_index, "entry_index": current_index,
@ -302,19 +303,18 @@ def simulate_trades_1s(env):
""" """
When the main timeframe is 1s, scan the entire remaining window to detect local extrema. When the main timeframe is 1s, scan the entire remaining window to detect local extrema.
If at least two extrema are found, pair consecutive extrema as trades. If at least two extrema are found, pair consecutive extrema as trades.
If none (or too few) are found, fallback to manual_trade. Use the candle "close" prices for trade reward calculation.
If too few extrema are found, fallback to manual_trade.
""" """
n = len(env.candle_window) n = len(env.candle_window)
extrema = [] extrema = []
for i in range(env.current_index, n): for i in range(env.current_index, n):
# Add first and last points.
if i == env.current_index or i == n-1: if i == env.current_index or i == n-1:
extrema.append(i) extrema.append(i)
else: else:
prev = env.candle_window[i-1]["close"] prev = env.candle_window[i-1]["close"]
curr = env.candle_window[i]["close"] curr = env.candle_window[i]["close"]
nex = env.candle_window[i+1]["close"] nex = env.candle_window[i+1]["close"]
# A valley or a peak.
if curr < prev and curr < nex: if curr < prev and curr < nex:
extrema.append(i) extrema.append(i)
elif curr > prev and curr > nex: elif curr > prev and curr > nex:
@ -322,12 +322,11 @@ def simulate_trades_1s(env):
if len(extrema) < 2: if len(extrema) < 2:
manual_trade(env) manual_trade(env)
return return
# Process consecutive extrema into trades.
for j in range(len(extrema)-1): for j in range(len(extrema)-1):
entry_idx = extrema[j] entry_idx = extrema[j]
exit_idx = extrema[j+1] exit_idx = extrema[j+1]
entry_price = env.candle_window[entry_idx]["open"] entry_price = env.candle_window[entry_idx]["close"]
exit_price = env.candle_window[exit_idx]["open"] exit_price = env.candle_window[exit_idx]["close"]
if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]: if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]:
reward = exit_price - entry_price reward = exit_price - entry_price
else: else:
@ -348,7 +347,7 @@ def simulate_trades(model, env, device, args):
Simulate trades over the current sliding window. Simulate trades over the current sliding window.
If the main timeframe is 1s, use local extrema detection. If the main timeframe is 1s, use local extrema detection.
Otherwise, check if the model's predicted potentials exceed the threshold. Otherwise, check if the model's predicted potentials exceed the threshold.
If so, execute the model decision; otherwise, call the manual_trade override. Use manual_trade if the model's signal is too weak.
""" """
if args.main_tf == "1s": if args.main_tf == "1s":
simulate_trades_1s(env) simulate_trades_1s(env)
@ -381,8 +380,8 @@ def update_live_html(candles, trade_history, epoch, loss):
""" """
Generate an HTML page with a live chart. Generate an HTML page with a live chart.
The chart displays price (line) and volume (bars on a secondary y-axis), The chart displays price (line) and volume (bars on a secondary y-axis),
and includes buy/sell markers with dotted connecting lines. and includes trade markers with dotted connecting lines.
The page title shows the epoch, loss, and PnL. The title shows the epoch, loss, and total PnL.
The page auto-refreshes every 1 second. The page auto-refreshes every 1 second.
""" """
from io import BytesIO from io import BytesIO
@ -390,7 +389,7 @@ def update_live_html(candles, trade_history, epoch, loss):
fig, ax = plt.subplots(figsize=(12, 6)) fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history) update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}") ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}")
buf = BytesIO() buf = BytesIO()
fig.savefig(buf, format='png') fig.savefig(buf, format='png')
plt.close(fig) plt.close(fig)
@ -423,7 +422,7 @@ def update_live_html(candles, trade_history, epoch, loss):
</head> </head>
<body> <body>
<div class="chart-container"> <div class="chart-container">
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}</h2> <h2>Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/> <img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div> </div>
</body> </body>
@ -438,7 +437,7 @@ def update_live_chart(ax, candles, trade_history):
""" """
Plot the price chart with actual timestamps (date and time in short format) Plot the price chart with actual timestamps (date and time in short format)
and volume on a secondary y-axis. and volume on a secondary y-axis.
Mark BUY (green) and SELL (red) points with dotted lines connecting. Mark trade entry (green) and exit (red) points, with dotted lines connecting them.
""" """
ax.clear() ax.clear()
times = [convert_timestamp(candle["timestamp"]) for candle in candles] times = [convert_timestamp(candle["timestamp"]) for candle in candles]
@ -448,7 +447,6 @@ def update_live_chart(ax, candles, trade_history):
ax.set_ylabel("Price") ax.set_ylabel("Price")
# Use short format for date and time. # Use short format for date and time.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
# Plot volume on secondary y-axis.
ax2 = ax.twinx() ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles] volumes = [candle["volume"] for candle in candles]
if len(times) > 1: if len(times) > 1:
@ -458,7 +456,6 @@ def update_live_chart(ax, candles, trade_history):
bar_width = 0.01 bar_width = 0.01
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
ax2.set_ylabel("Volume") ax2.set_ylabel("Volume")
# Plot trade markers.
for trade in trade_history: for trade in trade_history:
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
@ -467,7 +464,6 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
# Combine legends from both axes.
lines, labels = ax.get_legend_handles_labels() lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2) ax.legend(lines + lines2, labels + labels2)
@ -534,7 +530,8 @@ class BacktestEnvironment:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else: else:
if action == 0: # SELL (close trade) if action == 0: # SELL (close trade)
exit_price = next_candle["open"] # Use the "close" price to compute reward.
exit_price = next_candle["close"]
reward = exit_price - self.position["entry_price"] reward = exit_price - self.position["entry_price"]
trade = { trade = {
"entry_index": self.position["entry_index"], "entry_index": self.position["entry_index"],