This commit is contained in:
Dobromir Popov 2025-02-04 22:23:23 +02:00
parent f32f648bf0
commit 907468239a

View File

@ -22,6 +22,18 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
"""
Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (i.e. in milliseconds),
it is divided by 1000.
"""
ts = float(ts)
if ts > 1e10: # Likely in milliseconds
ts = ts / 1000.0
return datetime.fromtimestamp(ts)
# --- Directories --- # --- Directories ---
LAST_DIR = os.path.join("models", "last") LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best") BEST_DIR = os.path.join("models", "best")
@ -219,11 +231,10 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
old_embed = old_state["timeframe_embed.weight"] old_embed = old_state["timeframe_embed.weight"]
new_embed = new_state["timeframe_embed.weight"] new_embed = new_state["timeframe_embed.weight"]
if old_embed.shape[0] < new_embed.shape[0]: if old_embed.shape[0] < new_embed.shape[0]:
# Copy the available rows and keep the remaining as initialized.
new_embed[:old_embed.shape[0]] = old_embed new_embed[:old_embed.shape[0]] = old_embed
old_state["timeframe_embed.weight"] = new_embed old_state["timeframe_embed.weight"] = new_embed
# For channel_branches, if there are missing keys, load_state_dict with strict=False. # For channel_branches, missing keys are handled by strict=False.
model.load_state_dict(old_state, strict=False) model.load_state_dict(old_state, strict=False)
return checkpoint return checkpoint
@ -232,7 +243,7 @@ def update_live_html(candles, trade_history, epoch):
""" """
Generate a chart image that uses actual timestamps on the x-axis Generate a chart image that uses actual timestamps on the x-axis
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines) and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
is embedded in an HTML page that auto-refreshes every 10 seconds. is embedded in an HTML page that auto-refreshes every 1 seconds.
""" """
from io import BytesIO from io import BytesIO
import base64 import base64
@ -252,7 +263,7 @@ def update_live_html(candles, trade_history, epoch):
<html> <html>
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<meta http-equiv="refresh" content="10"> <meta http-equiv="refresh" content="1">
<title>Live Trading Chart - Epoch {epoch}</title> <title>Live Trading Chart - Epoch {epoch}</title>
<style> <style>
body {{ body {{
@ -291,15 +302,15 @@ def update_live_chart(ax, candles, trade_history):
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit. Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
""" """
ax.clear() ax.clear()
# Convert timestamps to datetime objects. # Use the helper to convert timestamps safely.
times = [datetime.fromtimestamp(candle["timestamp"]) for candle in candles] times = [convert_timestamp(candle["timestamp"]) for candle in candles]
close_prices = [candle["close"] for candle in candles] close_prices = [candle["close"] for candle in candles]
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels. # Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
for trade in trade_history: for trade in trade_history:
entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"]) entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"]) exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
in_price = trade["entry_price"] in_price = trade["entry_price"]
out_price = trade["exit_price"] out_price = trade["exit_price"]
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
@ -518,7 +529,6 @@ async def main():
except Exception as e: except Exception as e:
print("Failed to load optimizer state due to:", e) print("Failed to load optimizer state due to:", e)
print("Deleting all checkpoints and starting fresh.") print("Deleting all checkpoints and starting fresh.")
# Delete checkpoint files.
for chk_dir in [LAST_DIR, BEST_DIR]: for chk_dir in [LAST_DIR, BEST_DIR]:
for f in os.listdir(chk_dir): for f in os.listdir(chk_dir):
os.remove(os.path.join(chk_dir, f)) os.remove(os.path.join(chk_dir, f))