try to fix total pnl
This commit is contained in:
parent
a04a2d7677
commit
d29cc312fd
@ -59,6 +59,7 @@ BEST_DIR = os.path.join("models", "best")
|
||||
os.makedirs(LAST_DIR, exist_ok=True)
|
||||
os.makedirs(BEST_DIR, exist_ok=True)
|
||||
CACHE_FILE = "candles_cache.json"
|
||||
TRAINING_CACHE_FILE = "training_cache.json"
|
||||
|
||||
# --- Constants ---
|
||||
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
|
||||
@ -66,6 +67,25 @@ NUM_INDICATORS = 20 # e.g., 20 technical indicators
|
||||
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
|
||||
ORDER_CHANNELS = 1 # One extra channel for order information.
|
||||
|
||||
# --- Training Cache Helpers ---
|
||||
def load_training_cache(filename):
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
cache = json.load(f)
|
||||
print(f"Loaded training cache from {filename}.")
|
||||
return cache
|
||||
except Exception as e:
|
||||
print("Error loading training cache:", e)
|
||||
return {"total_pnl": 0.0}
|
||||
|
||||
def save_training_cache(filename, cache):
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(cache, f)
|
||||
except Exception as e:
|
||||
print("Error saving training cache:", e)
|
||||
|
||||
# --- Positional Encoding Module ---
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
@ -376,7 +396,7 @@ def simulate_trades(model, env, device, args):
|
||||
break
|
||||
|
||||
# --- Live HTML Chart Update (with Volume and Loss) ---
|
||||
def update_live_html(candles, trade_history, epoch, loss):
|
||||
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
|
||||
"""
|
||||
Generate an HTML page with a live chart.
|
||||
The chart displays price (line) and volume (bars on a secondary y-axis),
|
||||
@ -389,7 +409,7 @@ def update_live_html(candles, trade_history, epoch, loss):
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
update_live_chart(ax, candles, trade_history)
|
||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
||||
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}")
|
||||
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}")
|
||||
buf = BytesIO()
|
||||
fig.savefig(buf, format='png')
|
||||
plt.close(fig)
|
||||
@ -422,7 +442,7 @@ def update_live_html(candles, trade_history, epoch, loss):
|
||||
</head>
|
||||
<body>
|
||||
<div class="chart-container">
|
||||
<h2>Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}</h2>
|
||||
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h2>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
|
||||
</div>
|
||||
</body>
|
||||
@ -445,7 +465,6 @@ def update_live_chart(ax, candles, trade_history):
|
||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||
ax.set_xlabel("Time")
|
||||
ax.set_ylabel("Price")
|
||||
# Use short format for date and time.
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
|
||||
ax2 = ax.twinx()
|
||||
volumes = [candle["volume"] for candle in candles]
|
||||
@ -530,8 +549,7 @@ class BacktestEnvironment:
|
||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||
else:
|
||||
if action == 0: # SELL (close trade)
|
||||
# Use the "close" price to compute reward.
|
||||
exit_price = next_candle["close"]
|
||||
exit_price = next_candle["close"] # use close price
|
||||
reward = exit_price - self.position["entry_price"]
|
||||
trade = {
|
||||
"entry_index": self.position["entry_index"],
|
||||
@ -551,7 +569,9 @@ class BacktestEnvironment:
|
||||
# --- Enhanced Training Loop ---
|
||||
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
|
||||
lambda_trade = args.lambda_trade
|
||||
total_pnl = 0.0
|
||||
# Load any saved total PnL from training cache:
|
||||
training_cache = load_training_cache(TRAINING_CACHE_FILE)
|
||||
total_pnl = training_cache.get("total_pnl", 0.0)
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
env.reset()
|
||||
loss_accum = 0.0
|
||||
@ -580,11 +600,18 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
||||
loss_accum += loss.item()
|
||||
scheduler.step()
|
||||
epoch_loss = loss_accum / steps
|
||||
total_pnl += sum(trade["pnl"] for trade in env.trade_history)
|
||||
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
|
||||
# If no trades occurred during the epoch, multiply the loss by 3.
|
||||
if len(env.trade_history) == 0:
|
||||
epoch_loss *= 3
|
||||
epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
|
||||
total_pnl += epoch_pnl
|
||||
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
|
||||
save_checkpoint(model, optimizer, epoch, loss_accum)
|
||||
simulate_trades(model, env, device, args)
|
||||
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
|
||||
update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl)
|
||||
# Update training cache with the new total PnL:
|
||||
training_cache["total_pnl"] = total_pnl
|
||||
save_training_cache(TRAINING_CACHE_FILE, training_cache)
|
||||
|
||||
# --- Live Plotting (for Live Mode) ---
|
||||
def live_preview_loop(candles, env):
|
||||
|
File diff suppressed because one or more lines are too long
1
crypto/brian/training_cache.json
Normal file
1
crypto/brian/training_cache.json
Normal file
@ -0,0 +1 @@
|
||||
{"total_pnl": 0.0}
|
Loading…
x
Reference in New Issue
Block a user