From 39ce1523913c5c7582f5d57ade4815e700701d50 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 4 Feb 2025 22:16:39 +0200 Subject: [PATCH] more efficient traning --- crypto/brian/index-deep-new.py | 97 +++++++++++++--------------------- 1 file changed, 36 insertions(+), 61 deletions(-) diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 94bf4c6..d6253a6 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -68,9 +68,10 @@ class TradingModel(nn.Module): # Embedding for channels 0..num_channels-1. self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim) + # Set batch_first=True to avoid the nested tensor warning. encoder_layers = TransformerEncoderLayer( d_model=hidden_dim, nhead=4, dim_feedforward=512, - dropout=0.1, activation='gelu', batch_first=False + dropout=0.1, activation='gelu', batch_first=True ) self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.attn_pool = nn.Linear(hidden_dim, 1) @@ -92,13 +93,13 @@ class TradingModel(nn.Module): channel_out = self.channel_branches[i](x[:, i, :]) channel_outs.append(channel_out) stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] - stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden] - tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) - stacked = stacked + tf_embeds - src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) - transformer_out = self.transformer(stacked, mask=src_mask) - attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0) - aggregated = (transformer_out * attn_weights).sum(dim=0) + # Notice that with batch_first=True, we want shape [batch, channels, hidden] + tf_embeds = self.timeframe_embed(timeframe_ids) + stacked = stacked + tf_embeds.unsqueeze(0) # add embedding to each sample in batch + # The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True. + transformer_out = self.transformer(stacked) + attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) + aggregated = (transformer_out * attn_weights).sum(dim=1) return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze() # --- Technical Indicator Helpers --- @@ -210,7 +211,21 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): path = os.path.join(best_dir, best_file) print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}") checkpoint = torch.load(path) - model.load_state_dict(checkpoint["model_state_dict"]) + old_state = checkpoint["model_state_dict"] + new_state = model.state_dict() + + # Fix the size mismatch for timeframe_embed.weight. + if "timeframe_embed.weight" in old_state: + old_embed = old_state["timeframe_embed.weight"] + new_embed = new_state["timeframe_embed.weight"] + if old_embed.shape[0] < new_embed.shape[0]: + # Copy the available rows and keep the remaining as initialized. + new_embed[:old_embed.shape[0]] = old_embed + old_state["timeframe_embed.weight"] = new_embed + + # (For channel_branches, if the checkpoint has fewer branches than your new model expects, + # missing branches will be left at their randomly initialized values.) + model.load_state_dict(old_state, strict=False) return checkpoint # --- Live HTML Chart Update --- @@ -283,24 +298,13 @@ def update_live_chart(ax, candles, trade_history): ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) # Format x-axis date labels. ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) - # Plot each trade. - buy_label_added = False - sell_label_added = False for trade in trade_history: entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"]) exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"]) in_price = trade["entry_price"] out_price = trade["exit_price"] - if not buy_label_added: - ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") - buy_label_added = True - else: - ax.plot(entry_time, in_price, marker="^", color="green", markersize=10) - if not sell_label_added: - ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") - sell_label_added = True - else: - ax.plot(exit_time, out_price, marker="v", color="red", markersize=10) + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") ax.set_xlabel("Time") ax.set_ylabel("Price") @@ -325,7 +329,6 @@ def simulate_trades(model, env, device, args): pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high = pred_high.item() pred_low = pred_low.item() - # Simple decision rule based on predicted move. if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: action = 2 # BUY elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: @@ -349,7 +352,6 @@ class BacktestEnvironment: self.reset() def reset(self): - # Randomly select a sliding window from the full dataset. self.start_index = random.randint(0, len(self.full_candles) - self.window_size) self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] self.current_index = 0 @@ -361,12 +363,6 @@ class BacktestEnvironment: return self.window_size def get_order_features(self, index): - """ - Returns a list of 7 features for the order channel. - If an order is open, the first element is 1.0 and the second is the normalized difference: - (current open - entry_price) / current open. - Otherwise, returns zeros. - """ candle = self.candle_window[index] if self.position is None: return [0.0] * FEATURES_PER_CHANNEL @@ -376,13 +372,6 @@ class BacktestEnvironment: return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) def get_state(self, index): - """ - Build state features from: - - For each timeframe: features from the aligned candle. - - One extra channel: current order information. - - NUM_INDICATORS channels of zeros. - Each channel is a vector of length FEATURES_PER_CHANNEL. - """ state_features = [] base_ts = self.candle_window[index]["timestamp"] for tf in self.timeframes: @@ -393,36 +382,27 @@ class BacktestEnvironment: aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) features = get_features_for_tf(self.candles_dict[tf], aligned_idx) state_features.append(features) - # Append order channel. order_features = self.get_order_features(index) state_features.append(order_features) - # Append technical indicator channels. for _ in range(NUM_INDICATORS): state_features.append([0.0] * FEATURES_PER_CHANNEL) return np.array(state_features, dtype=np.float32) def step(self, action): - """ - Execute one step in the environment: - - action: 0 => SELL, 1 => HOLD, 2 => BUY. - - Trades are recorded when a BUY is followed by a SELL. - """ base = self.candle_window if self.current_index >= len(base) - 1: current_state = self.get_state(self.current_index) return current_state, 0.0, None, True, 0.0, 0.0 - current_state = self.get_state(self.current_index) next_index = self.current_index + 1 next_state = self.get_state(next_index) next_candle = base[next_index] reward = 0.0 - if self.position is None: - if action == 2: # BUY signal: enter at next open. + if action == 2: self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: - if action == 0: # SELL signal: exit at next open. + if action == 0: exit_price = next_candle["open"] reward = exit_price - self.position["entry_price"] trade = { @@ -434,7 +414,6 @@ class BacktestEnvironment: } self.trade_history.append(trade) self.position = None - self.current_index = next_index done = (self.current_index >= len(base) - 1) actual_high = next_candle["high"] @@ -443,11 +422,11 @@ class BacktestEnvironment: # --- Enhanced Training Loop --- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): - lambda_trade = args.lambda_trade # Weight for the surrogate profit loss. + lambda_trade = args.lambda_trade for epoch in range(start_epoch, args.epochs): - env.reset() # Resets the sliding window. + env.reset() loss_accum = 0.0 - steps = len(env) - 1 # We assume steps over consecutive candle pairs. + steps = len(env) - 1 for i in range(steps): state = env.get_state(i) current_open = env.candle_window[i]["open"] @@ -456,14 +435,11 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) - # Prediction loss (L1 error). L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ torch.abs(pred_low - torch.tensor(actual_low, device=device)) - # Surrogate profit loss. - profit_buy = pred_high - current_open # potential long gain - profit_sell = current_open - pred_low # potential short gain + profit_buy = pred_high - current_open + profit_sell = current_open - pred_low L_trade = - torch.max(profit_buy, profit_sell) - # Additional penalty if no strong signal is produced. current_open_tensor = torch.tensor(current_open, device=device) signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low) penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0) @@ -497,7 +473,7 @@ def parse_args(): parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") - parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).") + parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") return parser.parse_args() @@ -511,7 +487,6 @@ async def main(): print("Using device:", device) timeframes = ["1m", "5m", "15m", "1h", "1d"] hidden_dim = 128 - # Total channels: NUM_TIMEFRAMES + 1 (order info) + NUM_INDICATORS. total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device) @@ -541,7 +516,7 @@ async def main(): optimizer.load_state_dict(optim_state) print("Loaded optimizer state from checkpoint.") else: - print("No valid optimizer state found in checkpoint; starting fresh optimizer state.") + print("No valid optimizer state found; using fresh optimizer state.") train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) elif args.mode == 'live': @@ -576,7 +551,7 @@ async def main(): elif args.mode == 'inference': load_best_checkpoint(model) print("Running inference...") - # Your inference logic goes here. + # Inference logic goes here. else: print("Invalid mode specified.")