From 375aebee887b9393c7745abaf4b623bf66d6c72e Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 4 Feb 2025 22:05:41 +0200
Subject: [PATCH] better training algo

---
 crypto/brian/index-deep-new.py | 147 +++++++++++++++++++++------------
 1 file changed, 96 insertions(+), 51 deletions(-)

diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index 54ea5d9..ab50293 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -32,7 +32,7 @@ CACHE_FILE = "candles_cache.json"
 # --- Constants ---
 NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
 NUM_INDICATORS = 20 # e.g., 20 technical indicators
-FEATURES_PER_CHANNEL = 7 # H, L, O, C, Volume, SMA_close, SMA_volume
+FEATURES_PER_CHANNEL = 7 # e.g., H, L, O, C, Volume, SMA_close, SMA_volume
 
 # --- Positional Encoding Module ---
 class PositionalEncoding(nn.Module):
@@ -53,7 +53,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create branch for each channel
+        # Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features)
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -82,14 +82,14 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )
     def forward(self, x, timeframe_ids):
-        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
+        stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
         # Add embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
@@ -182,7 +182,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
-    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
@@ -206,7 +205,6 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    # Choose the checkpoint with the lowest loss.
     best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
     print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
@@ -217,7 +215,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
 # --- Live HTML Chart Update ---
 def update_live_html(candles, trade_history, epoch):
     """
-    Generate a chart image with buy/sell markers and a dotted line between open and close,
+    Generate a chart image with buy/sell markers and a dotted line between open/close positions,
     then embed it in a simple HTML page that auto-refreshes every 10 seconds.
     """
     from io import BytesIO
@@ -271,7 +269,7 @@ def update_live_html(candles, trade_history, epoch):
 # --- Chart Drawing Helpers (used by both live preview and HTML update) ---
 def update_live_chart(ax, candles, trade_history):
     """
-    Plot the chart with close price, buy and sell markers, and dotted lines joining buy/sell entry/exit.
+    Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit.
     """
     ax.clear()
     close_prices = [candle["close"] for candle in candles]
@@ -300,13 +298,14 @@ def update_live_chart(ax, candles, trade_history):
     ax.legend()
     ax.grid(True)
 
-# --- Forced Action Helper ---
+# --- Forced Action & Optimal Hint Helpers ---
 def get_forced_action(env):
     """
-    Force at least one trade per episode:
-      - At the very first step, force a BUY (action 2) if no position is open.
-      - At the penultimate step, if a position is open, force a SELL (action 0).
-      - Otherwise, default to HOLD (action 1).
+    When simulating streaming data, we force a trade at strategic moments:
+      - At the very first step: force BUY.
+      - At the penultimate step: if a position is open, force SELL.
+      - Otherwise, default to HOLD.
+    (The environment will also apply a penalty if the chosen action does not match the optimal hint.)
     """
     total = len(env)
     if env.current_index == 0:
@@ -319,53 +318,98 @@ def get_forced_action(env):
     else:
         return 1 # HOLD
 
-# --- Backtest Environment ---
+# --- Backtest Environment with Sliding Window and Hints ---
 class BacktestEnvironment:
-    def __init__(self, candles_dict, base_tf, timeframes):
-        self.candles_dict = candles_dict # dict: timeframe -> list of candles
+    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
+        self.candles_dict = candles_dict # full dictionary of timeframe candles
         self.base_tf = base_tf
         self.timeframes = timeframes
-        self.current_index = 0
-        self.trade_history = []
-        self.position = None
+        # Keep the full candle list for the base timeframe.
+        self.full_candles = candles_dict[base_tf]
+        # Determine the sliding window size:
+        if window_size is None:
+            window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
+        self.window_size = window_size
+        self.hint_penalty = 0.001 # Penalty coefficient (multiplied by the open price)
+        self.reset()
 
     def reset(self):
+        # Pick a random sliding window from the full dataset.
+        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
+        self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size]
         self.current_index = 0
-        self.position = None
         self.trade_history = []
+        self.position = None
         return self.get_state(self.current_index)
 
+    def __len__(self):
+        return self.window_size
+
     def get_state(self, index):
+        """
+        Build state features from the candle at the current index of the base-timeframe
+        sliding window, plus the aligned candles of the other timeframes.
+        Zeros are appended as placeholders for the technical indicators.
+        """
         state_features = []
-        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
+        base_ts = self.candle_window[index]["timestamp"]
         for tf in self.timeframes:
-            aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
-            features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
+            if tf == self.base_tf:
+                # For the base timeframe, use the sliding window candle.
+                candle = self.candle_window[index]
+                features = get_features_for_tf([candle], 0) # single-element list
+            else:
+                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
+                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
             state_features.append(features)
         for _ in range(NUM_INDICATORS):
             state_features.append([0.0] * FEATURES_PER_CHANNEL)
         return np.array(state_features, dtype=np.float32)
-    
+
+    def compute_optimal_hint(self, horizon=10, threshold=0.005):
+        """
+        Look ahead up to `horizon` candles within the sliding window and
+        determine an optimal action hint:
+          2: BUY  if the price is expected to rise by at least `threshold`.
+          0: SELL if the price is expected to drop by at least `threshold`.
+          1: HOLD otherwise.
+        """
+        base = self.candle_window
+        if self.current_index >= len(base) - 1:
+            return 1 # HOLD
+        current_candle = base[self.current_index]
+        open_price = current_candle["open"]
+        future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))]
+        if not future_slice:
+            return 1
+        max_future = max(candle["high"] for candle in future_slice)
+        min_future = min(candle["low"] for candle in future_slice)
+        if (max_future - open_price) / open_price >= threshold:
+            return 2 # BUY
+        elif (open_price - min_future) / open_price >= threshold:
+            return 0 # SELL
+        else:
+            return 1 # HOLD
+
     def step(self, action):
-        base_candles = self.candles_dict[self.base_tf]
-        # End-of-data: return dummy high/low targets
-        if self.current_index >= len(base_candles) - 1:
+        base = self.candle_window
+        if self.current_index >= len(base) - 1:
             current_state = self.get_state(self.current_index)
             return current_state, 0.0, None, True, 0.0, 0.0
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        next_candle = base_candles[next_index]
+        next_candle = base[next_index]
         reward = 0.0
-        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
+        # Trade logic (0: SELL, 1: HOLD, 2: BUY)
         if self.position is None:
-            if action == 2: # BUY signal: enter at next candle's open.
+            if action == 2: # BUY: enter at next candle's open.
                 entry_price = next_candle["open"]
                 self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0: # SELL signal: exit at next candle's open.
+            if action == 0: # SELL: exit at next candle's open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
@@ -379,27 +423,30 @@ class BacktestEnvironment:
                 self.position = None
 
         self.current_index = next_index
-        done = (self.current_index >= len(base_candles) - 1)
+        done = (self.current_index >= len(base) - 1)
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
+
+        # Compute the optimal action hint and apply a penalty if the action deviates from it.
+        optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005)
+        if action != optimal_hint:
+            reward -= self.hint_penalty * next_candle["open"]
+
         return current_state, reward, next_state, done, actual_high, actual_low
 
-    def __len__(self):
-        return len(self.candles_dict[self.base_tf])
-
 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, base_candles):
+def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
     for epoch in range(start_epoch, args.epochs):
         state = env.reset()
-        total_loss = 0
+        total_loss = 0.0
         model.train()
         while True:
-            # Use forced action policy to guarantee at least one trade per episode
+            # Use forced-action policy for trading (guaranteeing at least one trade per episode)
             action = get_forced_action(env)
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Use the forced action in the environment step.
+            # Use our forced action in the environment step.
            _, reward, next_state, done, actual_high, actual_low = env.step(action)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
@@ -418,8 +465,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
         epoch_loss = total_loss / len(env)
         print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update the live HTML file with the current epoch chart.
-        update_live_html(base_candles, env.trade_history, epoch+1)
+        # Update the live HTML chart to display the current sliding window.
+        update_live_html(env.candle_window, env.trade_history, epoch+1)
 
 # --- Live Plotting Functions (For live mode) ---
 def live_preview_loop(candles, env):
@@ -461,7 +508,8 @@ async def main():
         print("No historical candle data available for backtesting.")
         return
     base_tf = "1m"
-    env = BacktestEnvironment(candles_dict, base_tf, timeframes)
+    # Create the environment with a sliding window (to simulate streaming data).
+    env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
 
     start_epoch = 0
     checkpoint = None
@@ -475,7 +523,6 @@ async def main():
     else:
         print("Starting training from scratch as requested.")
 
-    # Create optimizer and scheduler in main
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
     if checkpoint is not None:
@@ -485,8 +532,7 @@ async def main():
         print("Loaded optimizer state from checkpoint.")
     else:
         print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
-        # Pass the base timeframe candles for the live HTML chart update.
-        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])
+        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
 
     elif args.mode == 'live':
         load_best_checkpoint(model)
         candles_dict = load_candles_cache(CACHE_FILE)
         if not candles_dict:
             print("No cached candles available for live preview.")
             return
-        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
-        preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
+        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
+        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
         preview_thread.start()
-        print("Starting live trading loop. (Using forced action policy for simulation.)")
-        # Here we use the forced-action policy as in training.
+        print("Starting live trading loop. (Using forced-action policy for simulation.)")
         while True:
             action = get_forced_action(env)
             state, reward, next_state, done, _, _ = env.step(action)
             if done:
-                print("Reached end of simulated data, resetting environment.")
+                print("Reached end of simulation window, resetting environment.")
                 state = env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Here you can apply a similar forced-action policy or use a learned policy.
+        # Apply a similar (or learned) policy as needed.
     else:
         print("Invalid mode specified.")
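
Note for reviewers: the core of the new reward shaping is the lookahead hint from compute_optimal_hint plus a deduction of hint_penalty * next_candle["open"] whenever the chosen action disagrees with that hint. The snippet below is a minimal, self-contained sketch of that logic outside the BacktestEnvironment class; the standalone optimal_hint helper and the synthetic candle list are illustrative assumptions for this note only, not code from the repository.

# Minimal sketch of the hint/penalty reward shaping used in BacktestEnvironment.step.
# The standalone helper and the synthetic candles below are illustrative only.

def optimal_hint(candles, index, horizon=10, threshold=0.005):
    """Return 2 (BUY), 0 (SELL) or 1 (HOLD) from a lookahead over future candles."""
    if index >= len(candles) - 1:
        return 1
    open_price = candles[index]["open"]
    future = candles[index + 1: index + 1 + horizon]
    max_future = max(c["high"] for c in future)
    min_future = min(c["low"] for c in future)
    if (max_future - open_price) / open_price >= threshold:
        return 2
    if (open_price - min_future) / open_price >= threshold:
        return 0
    return 1

# Synthetic, steadily rising candles (hypothetical data for the example).
candles = [{"open": 100.0 + i, "high": 101.0 + i, "low": 99.0 + i, "close": 100.5 + i}
           for i in range(20)]

hint_penalty = 0.001            # same coefficient as in the patch
action = 1                      # suppose the policy chose HOLD at step 0
reward = 0.0
hint = optimal_hint(candles, 0)                  # -> 2 (BUY): prices rise well past 0.5%
if action != hint:
    reward -= hint_penalty * candles[1]["open"]  # penalty scaled by the next candle's open
print(hint, round(reward, 3))                    # 2 -0.101

This mirrors compute_optimal_hint and the hint_penalty deduction in step; it is only meant to make the reward shaping easy to verify in isolation.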