wip
This commit is contained in:
parent
f32f648bf0
commit
907468239a
@ -22,6 +22,18 @@ import matplotlib.dates as mdates
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# --- Helper Function for Timestamp Conversion ---
|
||||||
|
def convert_timestamp(ts):
|
||||||
|
"""
|
||||||
|
Safely converts a timestamp to a datetime object.
|
||||||
|
If the timestamp is abnormally high (i.e. in milliseconds),
|
||||||
|
it is divided by 1000.
|
||||||
|
"""
|
||||||
|
ts = float(ts)
|
||||||
|
if ts > 1e10: # Likely in milliseconds
|
||||||
|
ts = ts / 1000.0
|
||||||
|
return datetime.fromtimestamp(ts)
|
||||||
|
|
||||||
# --- Directories ---
|
# --- Directories ---
|
||||||
LAST_DIR = os.path.join("models", "last")
|
LAST_DIR = os.path.join("models", "last")
|
||||||
BEST_DIR = os.path.join("models", "best")
|
BEST_DIR = os.path.join("models", "best")
|
||||||
@ -219,11 +231,10 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
|||||||
old_embed = old_state["timeframe_embed.weight"]
|
old_embed = old_state["timeframe_embed.weight"]
|
||||||
new_embed = new_state["timeframe_embed.weight"]
|
new_embed = new_state["timeframe_embed.weight"]
|
||||||
if old_embed.shape[0] < new_embed.shape[0]:
|
if old_embed.shape[0] < new_embed.shape[0]:
|
||||||
# Copy the available rows and keep the remaining as initialized.
|
|
||||||
new_embed[:old_embed.shape[0]] = old_embed
|
new_embed[:old_embed.shape[0]] = old_embed
|
||||||
old_state["timeframe_embed.weight"] = new_embed
|
old_state["timeframe_embed.weight"] = new_embed
|
||||||
|
|
||||||
# For channel_branches, if there are missing keys, load_state_dict with strict=False.
|
# For channel_branches, missing keys are handled by strict=False.
|
||||||
model.load_state_dict(old_state, strict=False)
|
model.load_state_dict(old_state, strict=False)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
@ -232,7 +243,7 @@ def update_live_html(candles, trade_history, epoch):
|
|||||||
"""
|
"""
|
||||||
Generate a chart image that uses actual timestamps on the x-axis
|
Generate a chart image that uses actual timestamps on the x-axis
|
||||||
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
|
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
|
||||||
is embedded in an HTML page that auto-refreshes every 10 seconds.
|
is embedded in an HTML page that auto-refreshes every 1 seconds.
|
||||||
"""
|
"""
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import base64
|
import base64
|
||||||
@ -252,7 +263,7 @@ def update_live_html(candles, trade_history, epoch):
|
|||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8">
|
<meta charset="utf-8">
|
||||||
<meta http-equiv="refresh" content="10">
|
<meta http-equiv="refresh" content="1">
|
||||||
<title>Live Trading Chart - Epoch {epoch}</title>
|
<title>Live Trading Chart - Epoch {epoch}</title>
|
||||||
<style>
|
<style>
|
||||||
body {{
|
body {{
|
||||||
@ -291,15 +302,15 @@ def update_live_chart(ax, candles, trade_history):
|
|||||||
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
|
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
|
||||||
"""
|
"""
|
||||||
ax.clear()
|
ax.clear()
|
||||||
# Convert timestamps to datetime objects.
|
# Use the helper to convert timestamps safely.
|
||||||
times = [datetime.fromtimestamp(candle["timestamp"]) for candle in candles]
|
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
|
||||||
close_prices = [candle["close"] for candle in candles]
|
close_prices = [candle["close"] for candle in candles]
|
||||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||||
# Format x-axis date labels.
|
# Format x-axis date labels.
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
||||||
for trade in trade_history:
|
for trade in trade_history:
|
||||||
entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"])
|
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
|
||||||
exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"])
|
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
|
||||||
in_price = trade["entry_price"]
|
in_price = trade["entry_price"]
|
||||||
out_price = trade["exit_price"]
|
out_price = trade["exit_price"]
|
||||||
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
|
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
|
||||||
@ -518,7 +529,6 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to load optimizer state due to:", e)
|
print("Failed to load optimizer state due to:", e)
|
||||||
print("Deleting all checkpoints and starting fresh.")
|
print("Deleting all checkpoints and starting fresh.")
|
||||||
# Delete checkpoint files.
|
|
||||||
for chk_dir in [LAST_DIR, BEST_DIR]:
|
for chk_dir in [LAST_DIR, BEST_DIR]:
|
||||||
for f in os.listdir(chk_dir):
|
for f in os.listdir(chk_dir):
|
||||||
os.remove(os.path.join(chk_dir, f))
|
os.remove(os.path.join(chk_dir, f))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user