diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 4404189..4ca8610 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -59,6 +59,7 @@ BEST_DIR = os.path.join("models", "best") os.makedirs(LAST_DIR, exist_ok=True) os.makedirs(BEST_DIR, exist_ok=True) CACHE_FILE = "candles_cache.json" +TRAINING_CACHE_FILE = "training_cache.json" # --- Constants --- NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"] @@ -66,6 +67,25 @@ NUM_INDICATORS = 20 # e.g., 20 technical indicators FEATURES_PER_CHANNEL = 7 # Each channel has 7 features. ORDER_CHANNELS = 1 # One extra channel for order information. +# --- Training Cache Helpers --- +def load_training_cache(filename): + if os.path.exists(filename): + try: + with open(filename, "r") as f: + cache = json.load(f) + print(f"Loaded training cache from {filename}.") + return cache + except Exception as e: + print("Error loading training cache:", e) + return {"total_pnl": 0.0} + +def save_training_cache(filename, cache): + try: + with open(filename, "w") as f: + json.dump(cache, f) + except Exception as e: + print("Error saving training cache:", e) + # --- Positional Encoding Module --- class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): @@ -376,7 +396,7 @@ def simulate_trades(model, env, device, args): break # --- Live HTML Chart Update (with Volume and Loss) --- -def update_live_html(candles, trade_history, epoch, loss): +def update_live_html(candles, trade_history, epoch, loss, total_pnl): """ Generate an HTML page with a live chart. The chart displays price (line) and volume (bars on a secondary y-axis), @@ -389,7 +409,7 @@ def update_live_html(candles, trade_history, epoch, loss): fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history) - ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {epoch_pnl:.2f}") + ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}") buf = BytesIO() fig.savefig(buf, format='png') plt.close(fig) @@ -422,7 +442,7 @@ def update_live_html(candles, trade_history, epoch, loss):
-

Epoch {epoch} | Loss: {loss:.4f} | Total PnL: {epoch_pnl:.2f}

+

Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}

Live Chart
@@ -445,7 +465,6 @@ def update_live_chart(ax, candles, trade_history): ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.set_xlabel("Time") ax.set_ylabel("Price") - # Use short format for date and time. ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) ax2 = ax.twinx() volumes = [candle["volume"] for candle in candles] @@ -530,8 +549,7 @@ class BacktestEnvironment: self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: if action == 0: # SELL (close trade) - # Use the "close" price to compute reward. - exit_price = next_candle["close"] + exit_price = next_candle["close"] # use close price reward = exit_price - self.position["entry_price"] trade = { "entry_index": self.position["entry_index"], @@ -551,7 +569,9 @@ class BacktestEnvironment: # --- Enhanced Training Loop --- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): lambda_trade = args.lambda_trade - total_pnl = 0.0 + # Load any saved total PnL from training cache: + training_cache = load_training_cache(TRAINING_CACHE_FILE) + total_pnl = training_cache.get("total_pnl", 0.0) for epoch in range(start_epoch, args.epochs): env.reset() loss_accum = 0.0 @@ -580,11 +600,18 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s loss_accum += loss.item() scheduler.step() epoch_loss = loss_accum / steps - total_pnl += sum(trade["pnl"] for trade in env.trade_history) - print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Total PnL: {total_pnl:.2f}") + # If no trades occurred during the epoch, multiply the loss by 3. + if len(env.trade_history) == 0: + epoch_loss *= 3 + epoch_pnl = sum(trade["pnl"] for trade in env.trade_history) + total_pnl += epoch_pnl + print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}") save_checkpoint(model, optimizer, epoch, loss_accum) simulate_trades(model, env, device, args) - update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss) + update_live_html(env.candle_window, env.trade_history, epoch+1, epoch_loss, total_pnl) + # Update training cache with the new total PnL: + training_cache["total_pnl"] = total_pnl + save_training_cache(TRAINING_CACHE_FILE, training_cache) # --- Live Plotting (for Live Mode) --- def live_preview_loop(candles, env): diff --git a/crypto/brian/live_chart.html b/crypto/brian/live_chart.html index d9e07f6..e8747bc 100644 --- a/crypto/brian/live_chart.html +++ b/crypto/brian/live_chart.html @@ -3,8 +3,8 @@ - - Live Trading Chart - Epoch 100 + + Live Trading Chart - Epoch 376