diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index 4404189..4ca8610 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -59,6 +59,7 @@ BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
+TRAINING_CACHE_FILE = "training_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
@@ -66,6 +67,25 @@ NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
ORDER_CHANNELS = 1 # One extra channel for order information.
+# --- Training Cache Helpers ---
+def load_training_cache(filename):
+ if os.path.exists(filename):
+ try:
+ with open(filename, "r") as f:
+ cache = json.load(f)
+ print(f"Loaded training cache from {filename}.")
+ return cache
+ except Exception as e:
+ print("Error loading training cache:", e)
+ return {"total_pnl": 0.0}
+
+def save_training_cache(filename, cache):
+ try:
+ with open(filename, "w") as f:
+ json.dump(cache, f)
+ except Exception as e:
+ print("Error saving training cache:", e)
+
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
@@ -376,7 +396,7 @@ def simulate_trades(model, env, device, args):
break
# --- Live HTML Chart Update (with Volume and Loss) ---
-def update_live_html(candles, trade_history, epoch, loss):
+def update_live_html(candles, trade_history, epoch, loss, total_pnl):
"""
Generate an HTML page with a live chart.
The chart displays price (line) and volume (bars on a secondary y-axis),
@@ -389,7 +409,7 @@ def update_live_html(candles, trade_history, epoch, loss):
fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
- ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}")
+ ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}")
buf = BytesIO()
fig.savefig(buf, format='png')
plt.close(fig)
@@ -422,7 +442,7 @@ def update_live_html(candles, trade_history, epoch, loss):
-
Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}
+
Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}
@@ -445,7 +465,6 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
ax.set_xlabel("Time")
ax.set_ylabel("Price")
- # Use short format for date and time.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
@@ -530,8 +549,7 @@ class BacktestEnvironment:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else:
if action == 0: # SELL (close trade)
- # Use the "close" price to compute reward.
- exit_price = next_candle["close"]
+ exit_price = next_candle["close"] # use close price
reward = exit_price - self.position["entry_price"]
trade = {
"entry_index": self.position["entry_index"],
@@ -551,7 +569,9 @@ class BacktestEnvironment:
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
lambda_trade = args.lambda_trade
- total_pnl = 0.0
+ # Load any saved total PnL from training cache:
+ training_cache = load_training_cache(TRAINING_CACHE_FILE)
+ total_pnl = training_cache.get("total_pnl", 0.0)
for epoch in range(start_epoch, args.epochs):
env.reset()
loss_accum = 0.0
@@ -580,11 +600,18 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
loss_accum += loss.item()
scheduler.step()
epoch_loss = loss_accum / steps
- total_pnl += sum(trade["pnl"] for trade in env.trade_history)
- print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}")
+ # If no trades occurred during the epoch, multiply the loss by 3.
+ if len(env.trade_history) == 0:
+ epoch_loss *= 3
+ epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)
+ total_pnl += epoch_pnl
+ print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
save_checkpoint(model, optimizer, epoch, loss_accum)
simulate_trades(model, env, device, args)
- update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss)
+ update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl)
+ # Update training cache with the new total PnL:
+ training_cache["total_pnl"] = total_pnl
+ save_training_cache(TRAINING_CACHE_FILE, training_cache)
# --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env):
diff --git a/crypto/brian/live_chart.html b/crypto/brian/live_chart.html
index d9e07f6..e8747bc 100644
--- a/crypto/brian/live_chart.html
+++ b/crypto/brian/live_chart.html
@@ -3,8 +3,8 @@
-
- Live Trading Chart - Epoch 100
+
+ Live Trading Chart - Epoch 376