From 07047369c911979dfb42632f74baa5f37b364409 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 4 Feb 2025 21:20:38 +0200
Subject: [PATCH] wip

---
 crypto/brian/index-deep-new.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index 98fbfec..e6423eb 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -239,7 +239,7 @@ class BacktestEnvironment:
 
     def step(self, action):
         base_candles = self.candles_dict[self.base_tf]
-        # Handle end-of-data scenario: return dummy high/low values if needed.
+        # End-of-data: return dummy high/low targets
         if self.current_index >= len(base_candles) - 1:
             current_state = self.get_state(self.current_index)
             return current_state, 0.0, None, True, 0.0, 0.0
@@ -247,7 +247,6 @@ class BacktestEnvironment:
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        current_candle = base_candles[self.current_index]
         next_candle = base_candles[next_index]
         reward = 0.0
 
@@ -272,31 +271,31 @@ class BacktestEnvironment:
         self.current_index = next_index
         done = (self.current_index >= len(base_candles) - 1)
 
-        # Return the high and low of the next candle as the targets.
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
         return current_state, reward, next_state, done, actual_high, actual_low
+
     def __len__(self):
         return len(self.candles_dict[self.base_tf])
 
 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args):
+def train_on_historical_data(env, model, device, args, start_epoch=0):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
-    for epoch in range(args.epochs):
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+    for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # Expect shape[0]==num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
             pred_high, pred_low = model(state_tensor, timeframe_ids)
             # Dummy targets from next candle's high/low
             _, _, next_state, done, actual_high, actual_low = env.step(None)
             target_high = torch.FloatTensor([actual_high]).to(device)
-            target_low = torch.FloatTensor([actual_low]).to(device)
+            target_low  = torch.FloatTensor([actual_low]).to(device)
             high_loss = torch.abs(pred_high - target_high) * 2
-            low_loss = torch.abs(pred_low - target_low) * 2
+            low_loss  = torch.abs(pred_low - target_low) * 2
             loss = (high_loss + low_loss).mean()
             optimizer.zero_grad()
             loss.backward()
@@ -355,6 +354,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
+    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true',
                         help='Start training from scratch ignoring saved checkpoints.')
     return parser.parse_args()
@@ -364,12 +364,11 @@ def random_action():
 # --- Main Function ---
 async def main():
     args = parse_args()
+    # Use GPU if available, else fall back to CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
-    input_dim = len(timeframes) * 7  # 7 features per timeframe
     hidden_dim = 128
-    output_dim = 3  # SELL, HOLD, BUY
-    # Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
     total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
     model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
@@ -380,6 +379,7 @@ async def main():
         return
     base_tf = "1m"
     env = BacktestEnvironment(candles_dict, base_tf, timeframes)
 
+    start_epoch = 0
    if not args.start_fresh:
         checkpoint = load_best_checkpoint(model)
@@ -403,10 +403,10 @@ async def main():
         preview_thread.start()
         print("Starting live trading loop. (Using random actions for simulation.)")
         while True:
-            state, reward, next_state, done = env.step(random_action())
+            state, reward, next_state, done, _, _ = env.step(random_action())
             if done:
                 print("Reached end of simulated data, resetting environment.")
-                state = env.reset(clear_trade_history=False)
+                state = env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)