This commit is contained in:
Dobromir Popov 2025-02-05 10:26:55 +02:00
parent 13c4f72b01
commit a58f9810bd

View File

@ -22,11 +22,30 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# --- Fetch Real 1s Data (if main_tf=="1s") ---
def fetch_real_1s_data():
"""
Fetch real 1-second candle data from your API.
Replace the URL and parameters with those required by your data provider.
Expected data format: a list of dictionaries with keys: "timestamp", "open", "high", "low", "close", "volume"
"""
import requests
url = "https://api.example.com/1s-data" # <-- Replace with your actual endpoint.
try:
response = requests.get(url)
response.raise_for_status()
data = response.json()
print("Fetched real 1s data successfully.")
return data
except Exception as e:
print("Failed to fetch real 1s data:", e)
return []
# --- Helper Function for Timestamp Conversion --- # --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts): def convert_timestamp(ts):
""" """
Safely converts a timestamp to a datetime object. Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (e.g. in milliseconds), If the timestamp is abnormally high (e.g. in milliseconds),
it is divided by 1000. it is divided by 1000.
""" """
ts = float(ts) ts = float(ts)
@ -309,7 +328,6 @@ def simulate_trades_1s(env):
exit_idx = extrema[j+1] exit_idx = extrema[j+1]
entry_price = env.candle_window[entry_idx]["open"] entry_price = env.candle_window[entry_idx]["open"]
exit_price = env.candle_window[exit_idx]["open"] exit_price = env.candle_window[exit_idx]["open"]
# If the entry candles close is lower than exit candles close, this is a long trade.
if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]: if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]:
reward = exit_price - entry_price reward = exit_price - entry_price
else: else:
@ -330,8 +348,7 @@ def simulate_trades(model, env, device, args):
Simulate trades over the current sliding window. Simulate trades over the current sliding window.
If the main timeframe is 1s, use local extrema detection. If the main timeframe is 1s, use local extrema detection.
Otherwise, check if the model's predicted potentials exceed the threshold. Otherwise, check if the model's predicted potentials exceed the threshold.
- If so, execute the model decision. If so, execute the model decision; otherwise, call the manual_trade override.
- Otherwise, call the manual_trade override.
""" """
if args.main_tf == "1s": if args.main_tf == "1s":
simulate_trades_1s(env) simulate_trades_1s(env)
@ -359,12 +376,13 @@ def simulate_trades(model, env, device, args):
if env.current_index >= len(env.candle_window) - 1: if env.current_index >= len(env.candle_window) - 1:
break break
# --- Live HTML Chart Update (with Volume) --- # --- Live HTML Chart Update (with Volume and Loss) ---
def update_live_html(candles, trade_history, epoch): def update_live_html(candles, trade_history, epoch, loss):
""" """
Generate an HTML page with a live chart. Generate an HTML page with a live chart.
The chart displays price (line) and volume (bar chart on a secondary y-axis), The chart displays price (line) and volume (bars on a secondary y-axis),
and includes buy/sell markers with dotted lines connecting entries and exits. and includes buy/sell markers with dotted connecting lines.
The page title shows the epoch, loss, and PnL.
The page auto-refreshes every 1 second. The page auto-refreshes every 1 second.
""" """
from io import BytesIO from io import BytesIO
@ -372,7 +390,7 @@ def update_live_html(candles, trade_history, epoch):
fig, ax = plt.subplots(figsize=(12, 6)) fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history) update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}") ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}")
buf = BytesIO() buf = BytesIO()
fig.savefig(buf, format='png') fig.savefig(buf, format='png')
plt.close(fig) plt.close(fig)
@ -405,7 +423,7 @@ def update_live_html(candles, trade_history, epoch):
</head> </head>
<body> <body>
<div class="chart-container"> <div class="chart-container">
<h2>Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}</h2> <h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/> <img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div> </div>
</body> </body>
@ -415,11 +433,12 @@ def update_live_html(candles, trade_history, epoch):
f.write(html_content) f.write(html_content)
print("Updated live_chart.html.") print("Updated live_chart.html.")
# --- Chart Drawing Helpers (with Volume) --- # --- Chart Drawing Helpers (with Volume and Date+Time) ---
def update_live_chart(ax, candles, trade_history): def update_live_chart(ax, candles, trade_history):
""" """
Plot the price chart with actual timestamps and volume on a secondary y-axis. Plot the price chart with actual timestamps (date and time in short format)
Mark BUY (green) and SELL (red) points and connect them with dotted lines. and volume on a secondary y-axis.
Mark BUY (green) and SELL (red) points with dotted lines connecting.
""" """
ax.clear() ax.clear()
times = [convert_timestamp(candle["timestamp"]) for candle in candles] times = [convert_timestamp(candle["timestamp"]) for candle in candles]
@ -427,7 +446,8 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
ax.set_xlabel("Time") ax.set_xlabel("Time")
ax.set_ylabel("Price") ax.set_ylabel("Price")
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) # Use short format for date and time.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
# Plot volume on secondary y-axis. # Plot volume on secondary y-axis.
ax2 = ax.twinx() ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles] volumes = [candle["volume"] for candle in candles]
@ -447,7 +467,7 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
# Combine legends. # Combine legends from both axes.
lines, labels = ax.get_legend_handles_labels() lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2) ax.legend(lines + lines2, labels + labels2)
@ -455,7 +475,6 @@ def update_live_chart(ax, candles, trade_history):
fig = ax.get_figure() fig = ax.get_figure()
fig.autofmt_xdate() fig.autofmt_xdate()
# --- Backtest Environment with Sliding Window and Order Info --- # --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None): def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
@ -568,8 +587,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}") print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
save_checkpoint(model, optimizer, epoch, loss_accum) save_checkpoint(model, optimizer, epoch, loss_accum)
simulate_trades(model, env, device, args) simulate_trades(model, env, device, args)
update_live_html(env.candle_window, env.trade_history, epoch+1) update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
# --- Live Plotting (for Live Mode) --- # --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env): def live_preview_loop(candles, env):
@ -607,6 +625,11 @@ async def main():
print("Using device:", device) print("Using device:", device)
# Load cached candles. # Load cached candles.
candles_dict = load_candles_cache(CACHE_FILE) candles_dict = load_candles_cache(CACHE_FILE)
# If the desired main timeframe is 1s, attempt to fetch real 1s data.
if args.main_tf == "1s":
real_1s_data = fetch_real_1s_data()
if real_1s_data:
candles_dict["1s"] = real_1s_data
if not candles_dict: if not candles_dict:
print("No historical candle data available for backtesting.") print("No historical candle data available for backtesting.")
return return