try to fix total pnl

This commit is contained in:
Dobromir Popov 2025-02-05 11:07:20 +02:00
parent a04a2d7677
commit d29cc312fd
3 changed files with 42 additions and 14 deletions

View File

@ -59,6 +59,7 @@ BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
@ -66,6 +67,25 @@ NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
ORDER_CHANNELS = 1 # One extra channel for order information.
# --- Training Cache Helpers ---
def load_training_cache(filename):
if os.path.exists(filename):
try:
with open(filename, "r") as f:
cache = json.load(f)
print(f"Loaded training cache from {filename}.")
return cache
except Exception as e:
print("Error loading training cache:", e)
return {"total_pnl": 0.0}
def save_training_cache(filename, cache):
try:
with open(filename, "w") as f:
json.dump(cache, f)
except Exception as e:
print("Error saving training cache:", e)
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
@ -376,7 +396,7 @@ def simulate_trades(model, env, device, args):
break
# --- Live HTML Chart Update (with Volume and Loss) ---
def update_live_html(candles, trade_history, epoch, loss):
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
"""
Generate an HTML page with a live chart.
The chart displays price (line) and volume (bars on a secondary y-axis),
@ -389,7 +409,7 @@ def update_live_html(candles, trade_history, epoch, loss):
fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}")
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}")
buf = BytesIO()
fig.savefig(buf, format='png')
plt.close(fig)
@ -422,7 +442,7 @@ def update_live_html(candles, trade_history, epoch, loss):
</head>
<body>
<div class="chart-container">
<h2>Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}</h2>
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div>
</body>
@ -445,7 +465,6 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
ax.set_xlabel("Time")
ax.set_ylabel("Price")
# Use short format for date and time.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
@ -530,8 +549,7 @@ class BacktestEnvironment:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else:
if action == 0: # SELL (close trade)
# Use the "close" price to compute reward.
exit_price = next_candle["close"]
exit_price = next_candle["close"] # use close price
reward = exit_price - self.position["entry_price"]
trade = {
"entry_index": self.position["entry_index"],
@ -551,7 +569,9 @@ class BacktestEnvironment:
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
lambda_trade = args.lambda_trade
total_pnl = 0.0
# Load any saved total PnL from training cache:
training_cache = load_training_cache(TRAINING_CACHE_FILE)
total_pnl = training_cache.get("total_pnl", 0.0)
for epoch in range(start_epoch, args.epochs):
env.reset()
loss_accum = 0.0
@ -580,11 +600,18 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
loss_accum += loss.item()
scheduler.step()
epoch_loss = loss_accum / steps
total_pnl += sum(trade["pnl"] for trade in env.trade_history)
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
# If no trades occurred during the epoch, multiply the loss by 3.
if len(env.trade_history) == 0:
epoch_loss *= 3
epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
total_pnl += epoch_pnl
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
save_checkpoint(model, optimizer, epoch, loss_accum)
simulate_trades(model, env, device, args)
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl)
# Update training cache with the new total PnL:
training_cache["total_pnl"] = total_pnl
save_training_cache(TRAINING_CACHE_FILE, training_cache)
# --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env):

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1 @@
{"total_pnl": 0.0}