try to fix total pnl
This commit is contained in:
parent
a04a2d7677
commit
d29cc312fd
@ -59,6 +59,7 @@ BEST_DIR = os.path.join("models", "best")
|
|||||||
os.makedirs(LAST_DIR, exist_ok=True)
|
os.makedirs(LAST_DIR, exist_ok=True)
|
||||||
os.makedirs(BEST_DIR, exist_ok=True)
|
os.makedirs(BEST_DIR, exist_ok=True)
|
||||||
CACHE_FILE = "candles_cache.json"
|
CACHE_FILE = "candles_cache.json"
|
||||||
|
TRAINING_CACHE_FILE = "training_cache.json"
|
||||||
|
|
||||||
# --- Constants ---
|
# --- Constants ---
|
||||||
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
|
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
|
||||||
@ -66,6 +67,25 @@ NUM_INDICATORS = 20 # e.g., 20 technical indicators
|
|||||||
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
|
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
|
||||||
ORDER_CHANNELS = 1 # One extra channel for order information.
|
ORDER_CHANNELS = 1 # One extra channel for order information.
|
||||||
|
|
||||||
|
# --- Training Cache Helpers ---
|
||||||
|
def load_training_cache(filename):
|
||||||
|
if os.path.exists(filename):
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
cache = json.load(f)
|
||||||
|
print(f"Loaded training cache from {filename}.")
|
||||||
|
return cache
|
||||||
|
except Exception as e:
|
||||||
|
print("Error loading training cache:", e)
|
||||||
|
return {"total_pnl": 0.0}
|
||||||
|
|
||||||
|
def save_training_cache(filename, cache):
|
||||||
|
try:
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(cache, f)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error saving training cache:", e)
|
||||||
|
|
||||||
# --- Positional Encoding Module ---
|
# --- Positional Encoding Module ---
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||||
@ -376,7 +396,7 @@ def simulate_trades(model, env, device, args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# --- Live HTML Chart Update (with Volume and Loss) ---
|
# --- Live HTML Chart Update (with Volume and Loss) ---
|
||||||
def update_live_html(candles, trade_history, epoch, loss):
|
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
|
||||||
"""
|
"""
|
||||||
Generate an HTML page with a live chart.
|
Generate an HTML page with a live chart.
|
||||||
The chart displays price (line) and volume (bars on a secondary y-axis),
|
The chart displays price (line) and volume (bars on a secondary y-axis),
|
||||||
@ -389,7 +409,7 @@ def update_live_html(candles, trade_history, epoch, loss):
|
|||||||
fig, ax = plt.subplots(figsize=(12, 6))
|
fig, ax = plt.subplots(figsize=(12, 6))
|
||||||
update_live_chart(ax, candles, trade_history)
|
update_live_chart(ax, candles, trade_history)
|
||||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
||||||
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}")
|
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}")
|
||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
fig.savefig(buf, format='png')
|
fig.savefig(buf, format='png')
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
@ -422,7 +442,7 @@ def update_live_html(candles, trade_history, epoch, loss):
|
|||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="chart-container">
|
<div class="chart-container">
|
||||||
<h2>Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}</h2>
|
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h2>
|
||||||
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
|
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
@ -445,7 +465,6 @@ def update_live_chart(ax, candles, trade_history):
|
|||||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||||
ax.set_xlabel("Time")
|
ax.set_xlabel("Time")
|
||||||
ax.set_ylabel("Price")
|
ax.set_ylabel("Price")
|
||||||
# Use short format for date and time.
|
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
|
||||||
ax2 = ax.twinx()
|
ax2 = ax.twinx()
|
||||||
volumes = [candle["volume"] for candle in candles]
|
volumes = [candle["volume"] for candle in candles]
|
||||||
@ -530,8 +549,7 @@ class BacktestEnvironment:
|
|||||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||||
else:
|
else:
|
||||||
if action == 0: # SELL (close trade)
|
if action == 0: # SELL (close trade)
|
||||||
# Use the "close" price to compute reward.
|
exit_price = next_candle["close"] # use close price
|
||||||
exit_price = next_candle["close"]
|
|
||||||
reward = exit_price - self.position["entry_price"]
|
reward = exit_price - self.position["entry_price"]
|
||||||
trade = {
|
trade = {
|
||||||
"entry_index": self.position["entry_index"],
|
"entry_index": self.position["entry_index"],
|
||||||
@ -551,7 +569,9 @@ class BacktestEnvironment:
|
|||||||
# --- Enhanced Training Loop ---
|
# --- Enhanced Training Loop ---
|
||||||
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
|
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
|
||||||
lambda_trade = args.lambda_trade
|
lambda_trade = args.lambda_trade
|
||||||
total_pnl = 0.0
|
# Load any saved total PnL from training cache:
|
||||||
|
training_cache = load_training_cache(TRAINING_CACHE_FILE)
|
||||||
|
total_pnl = training_cache.get("total_pnl", 0.0)
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
env.reset()
|
env.reset()
|
||||||
loss_accum = 0.0
|
loss_accum = 0.0
|
||||||
@ -580,11 +600,18 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
|||||||
loss_accum += loss.item()
|
loss_accum += loss.item()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
epoch_loss = loss_accum / steps
|
epoch_loss = loss_accum / steps
|
||||||
total_pnl += sum(trade["pnl"] for trade in env.trade_history)
|
# If no trades occurred during the epoch, multiply the loss by 3.
|
||||||
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
|
if len(env.trade_history) == 0:
|
||||||
|
epoch_loss *= 3
|
||||||
|
epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
|
||||||
|
total_pnl += epoch_pnl
|
||||||
|
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
|
||||||
save_checkpoint(model, optimizer, epoch, loss_accum)
|
save_checkpoint(model, optimizer, epoch, loss_accum)
|
||||||
simulate_trades(model, env, device, args)
|
simulate_trades(model, env, device, args)
|
||||||
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
|
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl)
|
||||||
|
# Update training cache with the new total PnL:
|
||||||
|
training_cache["total_pnl"] = total_pnl
|
||||||
|
save_training_cache(TRAINING_CACHE_FILE, training_cache)
|
||||||
|
|
||||||
# --- Live Plotting (for Live Mode) ---
|
# --- Live Plotting (for Live Mode) ---
|
||||||
def live_preview_loop(candles, env):
|
def live_preview_loop(candles, env):
|
||||||
|
File diff suppressed because one or more lines are too long
1
crypto/brian/training_cache.json
Normal file
1
crypto/brian/training_cache.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"total_pnl": 0.0}
|
Loading…
x
Reference in New Issue
Block a user