From 2ec75e66cbbe1c5a9a93a9819d2f7c2841c6f67b Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 4 Feb 2025 18:00:58 +0200
Subject: [PATCH] compile fix

---
 crypto/brian/index-deep-new.py | 100 ++++++++++++++++++++++++---------
 1 file changed, 75 insertions(+), 25 deletions(-)

diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index 92fb52e..5259493 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -85,37 +85,49 @@ class TradingModel(nn.Module):
         )
 
     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, features]
+        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
 
-        # Process each channel
+        # Process each channel through its own branch
        channel_outs = []
         for i in range(num_channels):
-            channel_out = self.channel_branches[i](x[:,i,:])
+            channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
 
-        # Stack and add embeddings
-        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]
+        # Stack the channel outputs
+        stacked = torch.stack(channel_outs, dim=1)   # [batch, channels, hidden]
+        stacked = stacked.permute(1, 0, 2)           # [channels, batch, hidden]
 
-        # Add timeframe embeddings
+        # Add timeframe embeddings to each channel
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
 
-        # Transformer
-        src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool()
-        transformer_out = self.transformer(stacked, src_mask=src_mask.to(x.device))
+        # Apply Transformer with a causal mask over the channel tokens
+        src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
+        transformer_out = self.transformer(stacked, src_mask=src_mask)
 
-        # Attention Pool
+        # Attention pooling over channels
         attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
         aggregated = (transformer_out * attn_weights).sum(dim=0)
 
         return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
 
 # --- Enhanced Data Processing ---
+# Requires the helper functions get_aligned_candle_with_index and get_features_for_tf,
+# which must be defined elsewhere in this file.
 class BacktestEnvironment:
+    def __init__(self, candles_dict, base_tf, timeframes):
+        self.candles_dict = candles_dict
+        self.base_tf = base_tf
+        self.timeframes = timeframes
+        self.current_index = 0  # step pointer into the base timeframe
+
+    def reset(self):
+        self.current_index = 0
+        return self.get_state(self.current_index)
+
     def get_state(self, index):
-        """Returns shape [num_channels, FEATURES_PER_CHANNEL]"""
+        """Return the state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
         state_features = []
         base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
 
@@ -127,10 +139,29 @@ class BacktestEnvironment:
 
         # Indicator channels (placeholder - implement your indicators)
         for _ in range(NUM_INDICATORS):
-            # Add indicator calculation here
-            state_features.append([0.0]*FEATURES_PER_CHANNEL)
+            state_features.append([0.0] * FEATURES_PER_CHANNEL)
 
         return np.array(state_features, dtype=np.float32)
+
+    def step(self, action):
+        """
+        Advance the environment by one step.
+        The action is ignored here, since this is a backtest.
+        Returns: current candle, reward, next state, done flag, actual high, actual low.
+        """
+        # Dummy implementation: targets are taken from the next candle of the backtest data.
+        done = (self.current_index >= len(self.candles_dict[self.base_tf]) - 2)
+        current_candle = self.candles_dict[self.base_tf][self.current_index]
+        # Use the next candle's high/low as the prediction targets
+        next_candle = self.candles_dict[self.base_tf][self.current_index + 1]
+        actual_high = next_candle["high"]
+        actual_low = next_candle["low"]
+        self.current_index += 1
+        next_state = self.get_state(self.current_index)
+        return current_candle, 0.0, next_state, done, actual_high, actual_low
+
+    def __len__(self):
+        return len(self.candles_dict[self.base_tf])
 
 # --- Enhanced Training Loop ---
 def train_on_historical_data(env, model, device, args):
@@ -143,24 +174,24 @@
     model.train()
 
     while True:
-        # Prepare batch
+        # Prepare batch (batch size 1 for simplicity)
         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
         timeframe_ids = torch.arange(state.shape[0]).to(device)
 
         # Forward pass
         pred_high, pred_low = model(state_tensor, timeframe_ids)
 
-        # Get targets from next candle
+        # Get target values from the next candle (dummy targets from the environment)
         _, _, next_state, done, actual_high, actual_low = env.step(None)  # Dummy action
         target_high = torch.FloatTensor([actual_high]).to(device)
         target_low = torch.FloatTensor([actual_low]).to(device)
 
-        # Custom loss
+        # Custom loss: absolute error, scaled by 2
         high_loss = torch.abs(pred_high - target_high) * 2
         low_loss = torch.abs(pred_low - target_low) * 2
         loss = (high_loss + low_loss).mean()
 
-        # Backprop
+        # Backpropagation
         optimizer.zero_grad()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@@ -176,7 +207,7 @@
         print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
         save_checkpoint(model, epoch, total_loss)
 
-# --- Mode Handling ---
+# --- Mode Handling and Argument Parsing ---
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
@@ -185,31 +216,50 @@
     parser.add_argument('--threshold', type=float, default=0.005)
     return parser.parse_args()
 
+def load_best_checkpoint(model, best_dir="models/best"):
+    # Dummy implementation for loading the best checkpoint.
+    # In real usage, restore saved weights here, e.g. with torch.load(...).
+    print("Loading best checkpoint (dummy implementation)")
+    return
+
+def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
+    # Dummy implementation for saving checkpoints.
+    print(f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}")
+
+# --- Main Function ---
 async def main():
     args = parse_args()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Initialize model
     model = TradingModel(
-        num_channels=NUM_TIMEFRAMES+NUM_INDICATORS,
+        num_channels=NUM_TIMEFRAMES + NUM_INDICATORS,
         num_timeframes=NUM_TIMEFRAMES
     ).to(device)
 
     if args.mode == 'train':
-        # Initialize environment and train
-        env = BacktestEnvironment(...)
+        # Load historical candle data for backtesting
+        candles_dict = load_candles_cache("candles_cache.json")
+        if not candles_dict:
+            print("No historical candle data available for backtesting.")
+            return
+        base_tf = "1m"  # Base timeframe
+        timeframes = ["1m", "5m", "15m", "1h", "1d"]
+        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
         train_on_historical_data(env, model, device, args)
 
     elif args.mode == 'live':
         # Load model and connect to live data
         load_best_checkpoint(model)
         while True:
-            # Process live data
-            # Make predictions and execute trades
+            # Process live data: fetch live candles, make predictions, execute trades
+            print("Processing live data...")
             await asyncio.sleep(1)
 
     elif args.mode == 'inference':
         # Load model and run inference
         load_best_checkpoint(model)
-        # Generate signals without training
+        print("Running inference...")
+        # Add your inference logic here
 
 if __name__ == "__main__":
     asyncio.run(main())
\ No newline at end of file
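
Follow-up notes (reviewer sketches, not part of the diff):

1) The src_mask line in forward() builds a standard causal mask over the
channel tokens. A quick standalone check of that construction (illustrative
only):

import torch

# Same expression as in forward(); n stands in for stacked.size(0),
# the number of channel tokens.
n = 4
src_mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
print(src_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# True marks positions attention may NOT attend to, matching the boolean
# src_mask convention of nn.TransformerEncoder.

One thing worth verifying: the training loop builds
timeframe_ids = torch.arange(state.shape[0]) over all
NUM_TIMEFRAMES + NUM_INDICATORS channels, so timeframe_embed must be sized
for that many entries or the indicator channels will index out of range.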
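
2) The patch calls load_candles_cache, and get_state relies on
get_aligned_candle_with_index and get_features_for_tf, none of which appear
in the diff. A minimal sketch of plausible implementations, assuming candles
are dicts with "timestamp"/"open"/"high"/"low"/"close"/"volume" keys (the
signatures, field layout, and constant value below are assumptions, not
taken from the file):

import json

FEATURES_PER_CHANNEL = 7  # assumed value of the module-level constant

def load_candles_cache(path):
    # Assumed cache format: {"1m": [candle, ...], "5m": [...], ...}
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}

def get_aligned_candle_with_index(candles, base_ts):
    # Return (candle, index) of the last candle at or before base_ts.
    idx = 0
    for i, candle in enumerate(candles):
        if candle["timestamp"] <= base_ts:
            idx = i
        else:
            break
    return candles[idx], idx

def get_features_for_tf(candles, index):
    # Flatten one candle into a fixed-length vector, zero-padded
    # up to FEATURES_PER_CHANNEL.
    c = candles[index]
    features = [c["open"], c["high"], c["low"], c["close"], c["volume"]]
    features = features[:FEATURES_PER_CHANNEL]
    return features + [0.0] * (FEATURES_PER_CHANNEL - len(features))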
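
3) load_best_checkpoint and save_checkpoint are stubs here, and the training
loop passes total_loss as the reward argument, so a real implementation must
pick a direction for "best" (lower is better if it really is a loss). One
possible sketch under those assumptions:

import os
import torch

def save_checkpoint(model, epoch, loss, last_dir="models/last", best_dir="models/best"):
    # Always keep the latest weights; copy into best_dir when the loss
    # improves on the best value seen so far.
    os.makedirs(last_dir, exist_ok=True)
    os.makedirs(best_dir, exist_ok=True)
    state = {"epoch": epoch, "loss": loss, "model_state": model.state_dict()}
    torch.save(state, os.path.join(last_dir, "checkpoint.pt"))
    best_path = os.path.join(best_dir, "checkpoint.pt")
    if (not os.path.exists(best_path)
            or loss < torch.load(best_path, map_location="cpu")["loss"]):
        torch.save(state, best_path)

def load_best_checkpoint(model, best_dir="models/best"):
    # Restore weights from the best checkpoint, if one exists.
    best_path = os.path.join(best_dir, "checkpoint.pt")
    if not os.path.exists(best_path):
        print("No best checkpoint found; using fresh weights.")
        return
    state = torch.load(best_path, map_location="cpu")
    model.load_state_dict(state["model_state"])
    print(f"Loaded best checkpoint from epoch {state['epoch']} (loss {state['loss']:.4f})")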