From a04a2d7677a35f09823416a9e42bd37557c6278b Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 5 Feb 2025 10:53:21 +0200 Subject: [PATCH] o3 suggestions --- crypto/brian/index-deep-new.py | 43 ++++++++++++++++------------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 31dfba0..4404189 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -253,7 +253,8 @@ def manual_trade(env): When no sufficient action is taken by the model, use a fallback: Scan the remaining window for the global maximum and minimum. If the maximum occurs before the minimum, simulate a short trade; - otherwise simulate a long trade. Closes the trade at the candle where the chosen extreme occurs. + otherwise simulate a long trade. + Use the candle "close" prices to compute trade reward. """ current_index = env.current_index if current_index >= len(env.candle_window) - 1: @@ -273,8 +274,8 @@ def manual_trade(env): min_val = low_j i_min = j if i_max < i_min: - entry_price = env.candle_window[current_index]["open"] - exit_price = env.candle_window[i_min]["open"] + entry_price = env.candle_window[current_index]["close"] + exit_price = env.candle_window[i_min]["close"] reward = entry_price - exit_price trade = { "entry_index": current_index, @@ -284,8 +285,8 @@ def manual_trade(env): "pnl": reward } else: - entry_price = env.candle_window[current_index]["open"] - exit_price = env.candle_window[i_max]["open"] + entry_price = env.candle_window[current_index]["close"] + exit_price = env.candle_window[i_max]["close"] reward = exit_price - entry_price trade = { "entry_index": current_index, @@ -302,19 +303,18 @@ def simulate_trades_1s(env): """ When the main timeframe is 1s, scan the entire remaining window to detect local extrema. If at least two extrema are found, pair consecutive extrema as trades. - If none (or too few) are found, fallback to manual_trade. 
+ Use the candle "close" prices for trade reward calculation. + If too few extrema are found, fallback to manual_trade. """ n = len(env.candle_window) extrema = [] for i in range(env.current_index, n): - # Add first and last points. if i == env.current_index or i == n-1: extrema.append(i) else: prev = env.candle_window[i-1]["close"] curr = env.candle_window[i]["close"] nex = env.candle_window[i+1]["close"] - # A valley or a peak. if curr < prev and curr < nex: extrema.append(i) elif curr > prev and curr > nex: @@ -322,12 +322,11 @@ def simulate_trades_1s(env): if len(extrema) < 2: manual_trade(env) return - # Process consecutive extrema into trades. for j in range(len(extrema)-1): entry_idx = extrema[j] - exit_idx = extrema[j+1] - entry_price = env.candle_window[entry_idx]["open"] - exit_price = env.candle_window[exit_idx]["open"] + exit_idx = extrema[j+1] + entry_price = env.candle_window[entry_idx]["close"] + exit_price = env.candle_window[exit_idx]["close"] if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]: reward = exit_price - entry_price else: @@ -348,7 +347,7 @@ def simulate_trades(model, env, device, args): Simulate trades over the current sliding window. If the main timeframe is 1s, use local extrema detection. Otherwise, check if the model's predicted potentials exceed the threshold. - If so, execute the model decision; otherwise, call the manual_trade override. + Use manual_trade if the model's signal is too weak. """ if args.main_tf == "1s": simulate_trades_1s(env) @@ -381,8 +380,8 @@ def update_live_html(candles, trade_history, epoch, loss): """ Generate an HTML page with a live chart. The chart displays price (line) and volume (bars on a secondary y-axis), - and includes buy/sell markers with dotted connecting lines. - The page title shows the epoch, loss, and PnL. + and includes trade markers with dotted connecting lines. + The title shows the epoch, loss, and total PnL. The page auto-refreshes every 1 second. 
""" from io import BytesIO @@ -390,7 +389,7 @@ def update_live_html(candles, trade_history, epoch, loss): fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history) - ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}") + ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}") buf = BytesIO() fig.savefig(buf, format='png') plt.close(fig) @@ -423,7 +422,7 @@ def update_live_html(candles, trade_history, epoch, loss):
-

Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}

+

Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}

Live Chart
@@ -438,7 +437,7 @@ def update_live_chart(ax, candles, trade_history): """ Plot the price chart with actual timestamps (date and time in short format) and volume on a secondary y-axis. - Mark BUY (green) and SELL (red) points with dotted lines connecting. + Mark trade entry (green) and exit (red) points, with dotted lines connecting them. """ ax.clear() times = [convert_timestamp(candle["timestamp"]) for candle in candles] @@ -448,7 +447,6 @@ def update_live_chart(ax, candles, trade_history): ax.set_ylabel("Price") # Use short format for date and time. ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) - # Plot volume on secondary y-axis. ax2 = ax.twinx() volumes = [candle["volume"] for candle in candles] if len(times) > 1: @@ -458,7 +456,6 @@ def update_live_chart(ax, candles, trade_history): bar_width = 0.01 ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") ax2.set_ylabel("Volume") - # Plot trade markers. for trade in trade_history: entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) @@ -467,7 +464,6 @@ def update_live_chart(ax, candles, trade_history): ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") - # Combine legends from both axes. lines, labels = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax.legend(lines + lines2, labels + labels2) @@ -534,7 +530,8 @@ class BacktestEnvironment: self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: if action == 0: # SELL (close trade) - exit_price = next_candle["open"] + # Use the "close" price to compute reward. 
+ exit_price = next_candle["close"] reward = exit_price - self.position["entry_price"] trade = { "entry_index": self.position["entry_index"], @@ -563,7 +560,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s state = env.get_state(i) current_open = env.candle_window[i]["open"] actual_high = env.candle_window[i+1]["high"] - actual_low = env.candle_window[i+1]["low"] + actual_low = env.candle_window[i+1]["low"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids)