diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 188071e..7eecd9c 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -236,17 +236,21 @@ class BacktestEnvironment: for _ in range(NUM_INDICATORS): state_features.append([0.0] * FEATURES_PER_CHANNEL) return np.array(state_features, dtype=np.float32) - + def step(self, action): base_candles = self.candles_dict[self.base_tf] + # Handle end-of-data scenario: return dummy high/low values if needed. if self.current_index >= len(base_candles) - 1: - return self.get_state(self.current_index), 0.0, None, True + current_state = self.get_state(self.current_index) + return current_state, 0.0, None, True, 0.0, 0.0 + current_state = self.get_state(self.current_index) next_index = self.current_index + 1 next_state = self.get_state(next_index) current_candle = base_candles[self.current_index] next_candle = base_candles[next_index] reward = 0.0 + # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. if self.position is None: if action == 2: # BUY signal: enter at next candle's open. @@ -265,10 +269,13 @@ class BacktestEnvironment: } self.trade_history.append(trade) self.position = None + self.current_index = next_index done = (self.current_index >= len(base_candles) - 1) - return current_state, reward, next_state, done - + # Return the high and low of the next candle as the targets. + actual_high = next_candle["high"] + actual_low = next_candle["low"] + return current_state, reward, next_state, done, actual_high, actual_low def __len__(self): return len(self.candles_dict[self.base_tf])