This commit is contained in:
Dobromir Popov 2025-02-05 10:26:55 +02:00
parent 13c4f72b01
commit a58f9810bd

View File

@ -22,11 +22,30 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv
load_dotenv()
# --- Fetch Real 1s Data (if main_tf=="1s") ---
def fetch_real_1s_data():
"""
Fetch real 1-second candle data from your API.
Replace the URL and parameters with those required by your data provider.
Expected data format: a list of dictionaries with keys: "timestamp", "open", "high", "low", "close", "volume"
"""
import requests
url = "https://api.example.com/1s-data" # <-- Replace with your actual endpoint.
try:
response = requests.get(url)
response.raise_for_status()
data = response.json()
print("Fetched real 1s data successfully.")
return data
except Exception as e:
print("Failed to fetch real 1s data:", e)
return []
# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
"""
Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (e.g. in milliseconds),
If the timestamp is abnormally high (e.g. in milliseconds),
it is divided by 1000.
"""
ts = float(ts)
@ -309,7 +328,6 @@ def simulate_trades_1s(env):
exit_idx = extrema[j+1]
entry_price = env.candle_window[entry_idx]["open"]
exit_price = env.candle_window[exit_idx]["open"]
# If the entry candles close is lower than exit candles close, this is a long trade.
if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]:
reward = exit_price - entry_price
else:
@ -330,8 +348,7 @@ def simulate_trades(model, env, device, args):
Simulate trades over the current sliding window.
If the main timeframe is 1s, use local extrema detection.
Otherwise, check if the model's predicted potentials exceed the threshold.
- If so, execute the model decision.
- Otherwise, call the manual_trade override.
If so, execute the model decision; otherwise, call the manual_trade override.
"""
if args.main_tf == "1s":
simulate_trades_1s(env)
@ -359,12 +376,13 @@ def simulate_trades(model, env, device, args):
if env.current_index >= len(env.candle_window) - 1:
break
# --- Live HTML Chart Update (with Volume) ---
def update_live_html(candles, trade_history, epoch):
# --- Live HTML Chart Update (with Volume and Loss) ---
def update_live_html(candles, trade_history, epoch, loss):
"""
Generate an HTML page with a live chart.
The chart displays price (line) and volume (bar chart on a secondary y-axis),
and includes buy/sell markers with dotted lines connecting entries and exits.
The chart displays price (line) and volume (bars on a secondary y-axis),
and includes buy/sell markers with dotted connecting lines.
The page title shows the epoch, loss, and PnL.
The page auto-refreshes every 1 second.
"""
from io import BytesIO
@ -372,7 +390,7 @@ def update_live_html(candles, trade_history, epoch):
fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}")
buf = BytesIO()
fig.savefig(buf, format='png')
plt.close(fig)
@ -405,7 +423,7 @@ def update_live_html(candles, trade_history, epoch):
</head>
<body>
<div class="chart-container">
<h2>Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}</h2>
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div>
</body>
@ -415,11 +433,12 @@ def update_live_html(candles, trade_history, epoch):
f.write(html_content)
print("Updated live_chart.html.")
# --- Chart Drawing Helpers (with Volume) ---
# --- Chart Drawing Helpers (with Volume and Date+Time) ---
def update_live_chart(ax, candles, trade_history):
"""
Plot the price chart with actual timestamps and volume on a secondary y-axis.
Mark BUY (green) and SELL (red) points and connect them with dotted lines.
Plot the price chart with actual timestamps (date and time in short format)
and volume on a secondary y-axis.
Mark BUY (green) and SELL (red) points with dotted lines connecting.
"""
ax.clear()
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
@ -427,7 +446,8 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
ax.set_xlabel("Time")
ax.set_ylabel("Price")
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Use short format for date and time.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
# Plot volume on secondary y-axis.
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
@ -447,7 +467,7 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
# Combine legends.
# Combine legends from both axes.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2)
@ -455,7 +475,6 @@ def update_live_chart(ax, candles, trade_history):
fig = ax.get_figure()
fig.autofmt_xdate()
# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
@ -568,8 +587,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
save_checkpoint(model, optimizer, epoch, loss_accum)
simulate_trades(model, env, device, args)
update_live_html(env.candle_window, env.trade_history, epoch+1)
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
# --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env):
@ -607,6 +625,11 @@ async def main():
print("Using device:", device)
# Load cached candles.
candles_dict = load_candles_cache(CACHE_FILE)
# If the desired main timeframe is 1s, attempt to fetch real 1s data.
if args.main_tf == "1s":
real_1s_data = fetch_real_1s_data()
if real_1s_data:
candles_dict["1s"] = real_1s_data
if not candles_dict:
print("No historical candle data available for backtesting.")
return