From 10ff22eb42eb00df9b27dbe17b65f6f8bbe3fcf4 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 4 Feb 2025 21:52:11 +0200 Subject: [PATCH] forced actions --- .gitignore | 1 + crypto/brian/index-deep-new.py | 44 +++++++++++++++++++++++++--------- crypto/brian/live_chart.html | 6 ++--- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index bd91baa..39028d4 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ app_data.db crypto/sol/.vs/* crypto/brian/models/best/* crypto/brian/models/last/* +crypto/brian/live_chart.html diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 3107b6b..54ea5d9 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -186,7 +186,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B if len(best_models) < 10: add_to_best = True else: - # The worst saved checkpoint will have the highest loss. worst_loss, worst_file = max(best_models, key=lambda x: x[0]) if loss < worst_loss: add_to_best = True @@ -219,7 +218,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): def update_live_html(candles, trade_history, epoch): """ Generate a chart image with buy/sell markers and a dotted line between open and close, - then embed it in a simple HTML page that auto-refreshes. + then embed it in a simple HTML page that auto-refreshes every 10 seconds. """ from io import BytesIO import base64 @@ -301,6 +300,25 @@ def update_live_chart(ax, candles, trade_history): ax.legend() ax.grid(True) +# --- Forced Action Helper --- +def get_forced_action(env): + """ + Force at least one trade per episode: + - At the very first step, force a BUY (action 2) if no position is open. + - At the penultimate step, if a position is open, force a SELL (action 0). + - Otherwise, default to HOLD (action 1). + """ + total = len(env) + if env.current_index == 0: + return 2 # BUY + elif env.current_index >= total - 2: + if env.position is not None: + return 0 # SELL + else: + return 1 # HOLD + else: + return 1 # HOLD + # --- Backtest Environment --- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes): @@ -343,11 +361,11 @@ class BacktestEnvironment: # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. if self.position is None: - if action == 2: # BUY: enter at next candle's open. + if action == 2: # BUY signal: enter at next candle's open. entry_price = next_candle["open"] self.position = {"entry_price": entry_price, "entry_index": self.current_index} else: - if action == 0: # SELL: exit at next candle's open. + if action == 0: # SELL signal: exit at next candle's open. exit_price = next_candle["open"] reward = exit_price - self.position["entry_price"] trade = { @@ -376,11 +394,13 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s total_loss = 0 model.train() while True: + # Use forced action policy to guarantee at least one trade per episode + action = get_forced_action(env) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) - # Get targets from environment (dummy high/low from next candle) - _, _, next_state, done, actual_high, actual_low = env.step(None) + # Use the forced action in the environment step. + _, reward, next_state, done, actual_high, actual_low = env.step(action) target_high = torch.FloatTensor([actual_high]).to(device) target_low = torch.FloatTensor([actual_low]).to(device) high_loss = torch.abs(pred_high - target_high) * 2 @@ -398,7 +418,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s epoch_loss = total_loss / len(env) print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}") save_checkpoint(model, optimizer, epoch, total_loss) - # Update the live HTML file with the current epoch chart + # Update the live HTML file with the current epoch chart. update_live_html(base_candles, env.trade_history, epoch+1) # --- Live Plotting Functions (For live mode) --- @@ -465,7 +485,7 @@ async def main(): print("Loaded optimizer state from checkpoint.") else: print("No valid optimizer state found in checkpoint; starting fresh optimizer state.") - # Pass base candles from the base timeframe for HTML chart updates. + # Pass the base timeframe candles for the live HTML chart update. train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf]) elif args.mode == 'live': @@ -477,9 +497,11 @@ async def main(): env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes) preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True) preview_thread.start() - print("Starting live trading loop. (Using random actions for simulation.)") + print("Starting live trading loop. (Using forced action policy for simulation.)") + # Here we use the forced-action policy as in training. while True: - state, reward, next_state, done, _, _ = env.step(random_action()) + action = get_forced_action(env) + state, reward, next_state, done, _, _ = env.step(action) if done: print("Reached end of simulated data, resetting environment.") state = env.reset() @@ -487,7 +509,7 @@ async def main(): elif args.mode == 'inference': load_best_checkpoint(model) print("Running inference...") - # Your inference logic goes here. + # Here you can apply a similar forced-action policy or use a learned policy. else: print("Invalid mode specified.") diff --git a/crypto/brian/live_chart.html b/crypto/brian/live_chart.html index 9cacafb..0512b35 100644 --- a/crypto/brian/live_chart.html +++ b/crypto/brian/live_chart.html @@ -4,7 +4,7 @@ - Live Trading Chart - Epoch 9 + Live Trading Chart - Epoch 14