training works on CPU

This commit is contained in:
Dobromir Popov 2025-02-04 20:59:24 +02:00
parent 20d6542d2c
commit a25e1eb686

View File

@@ -239,14 +239,18 @@ class BacktestEnvironment:
def step(self, action): def step(self, action):
base_candles = self.candles_dict[self.base_tf] base_candles = self.candles_dict[self.base_tf]
# Handle end-of-data scenario: return dummy high/low values if needed.
if self.current_index >= len(base_candles) - 1: if self.current_index >= len(base_candles) - 1:
return self.get_state(self.current_index), 0.0, None, True current_state = self.get_state(self.current_index)
return current_state, 0.0, None, True, 0.0, 0.0
current_state = self.get_state(self.current_index) current_state = self.get_state(self.current_index)
next_index = self.current_index + 1 next_index = self.current_index + 1
next_state = self.get_state(next_index) next_state = self.get_state(next_index)
current_candle = base_candles[self.current_index] current_candle = base_candles[self.current_index]
next_candle = base_candles[next_index] next_candle = base_candles[next_index]
reward = 0.0 reward = 0.0
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
if self.position is None: if self.position is None:
if action == 2: # BUY signal: enter at next candle's open. if action == 2: # BUY signal: enter at next candle's open.
@@ -265,10 +269,13 @@ class BacktestEnvironment:
} }
self.trade_history.append(trade) self.trade_history.append(trade)
self.position = None self.position = None
self.current_index = next_index self.current_index = next_index
done = (self.current_index >= len(base_candles) - 1) done = (self.current_index >= len(base_candles) - 1)
return current_state, reward, next_state, done # Return the high and low of the next candle as the targets.
actual_high = next_candle["high"]
actual_low = next_candle["low"]
return current_state, reward, next_state, done, actual_high, actual_low
def __len__(self): def __len__(self):
return len(self.candles_dict[self.base_tf]) return len(self.candles_dict[self.base_tf])