training works on CPU
This commit is contained in:
parent
20d6542d2c
commit
a25e1eb686
@ -239,14 +239,18 @@ class BacktestEnvironment:
|
||||
|
||||
def step(self, action):
|
||||
base_candles = self.candles_dict[self.base_tf]
|
||||
# Handle end-of-data scenario: return dummy high/low values if needed.
|
||||
if self.current_index >= len(base_candles) - 1:
|
||||
return self.get_state(self.current_index), 0.0, None, True
|
||||
current_state = self.get_state(self.current_index)
|
||||
return current_state, 0.0, None, True, 0.0, 0.0
|
||||
|
||||
current_state = self.get_state(self.current_index)
|
||||
next_index = self.current_index + 1
|
||||
next_state = self.get_state(next_index)
|
||||
current_candle = base_candles[self.current_index]
|
||||
next_candle = base_candles[next_index]
|
||||
reward = 0.0
|
||||
|
||||
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
|
||||
if self.position is None:
|
||||
if action == 2: # BUY signal: enter at next candle's open.
|
||||
@ -265,10 +269,13 @@ class BacktestEnvironment:
|
||||
}
|
||||
self.trade_history.append(trade)
|
||||
self.position = None
|
||||
|
||||
self.current_index = next_index
|
||||
done = (self.current_index >= len(base_candles) - 1)
|
||||
return current_state, reward, next_state, done
|
||||
|
||||
# Return the high and low of the next candle as the targets.
|
||||
actual_high = next_candle["high"]
|
||||
actual_low = next_candle["low"]
|
||||
return current_state, reward, next_state, done, actual_high, actual_low
|
||||
def __len__(self):
|
||||
return len(self.candles_dict[self.base_tf])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user