Dobromir Popov 2025-02-04 21:20:38 +02:00
parent 070d58f2bf
commit 07047369c9


@@ -239,7 +239,7 @@ class BacktestEnvironment:
     def step(self, action):
         base_candles = self.candles_dict[self.base_tf]
-        # Handle end-of-data scenario: return dummy high/low values if needed.
+        # End-of-data: return dummy high/low targets
         if self.current_index >= len(base_candles) - 1:
             current_state = self.get_state(self.current_index)
             return current_state, 0.0, None, True, 0.0, 0.0
@@ -247,7 +247,6 @@ class BacktestEnvironment:
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        current_candle = base_candles[self.current_index]
         next_candle = base_candles[next_index]
         reward = 0.0
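With the unused current_candle gone, step() reads only the next candle. For reference, a minimal sketch of how a caller consumes the six-value return contract (the driver loop itself is hypothetical, not part of the diff):

    # Hypothetical driver; mirrors the unpacking used in the training loop below.
    state = env.reset()
    while True:
        _, reward, next_state, done, actual_high, actual_low = env.step(None)
        if done:
            break
        state = next_state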
@@ -272,24 +271,24 @@ class BacktestEnvironment:
         self.current_index = next_index
         done = (self.current_index >= len(base_candles) - 1)
-        # Return the high and low of the next candle as the targets.
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
         return current_state, reward, next_state, done, actual_high, actual_low

     def __len__(self):
         return len(self.candles_dict[self.base_tf])

 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args):
+def train_on_historical_data(env, model, device, args, start_epoch=0):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
-    for epoch in range(args.epochs):
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+    for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # Expect shape[0]==num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
             pred_high, pred_low = model(state_tensor, timeframe_ids)
             # Dummy targets from next candle's high/low
             _, _, next_state, done, actual_high, actual_low = env.step(None)
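The hunk cuts off before the loss step. A minimal sketch of a plausible completion, assuming pred_high and pred_low are (1, 1) tensors and an MSE objective (both the loss choice and the shapes are assumptions, not shown in the diff):

    # Sketch only: compare predictions against the next candle's high/low.
    target = torch.tensor([[actual_high, actual_low]], dtype=torch.float32, device=device)
    pred = torch.cat([pred_high, pred_low], dim=1)  # assumes each is (1, 1)
    loss = torch.nn.functional.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    state = next_state
    if done:
        break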
@@ -355,6 +354,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
+    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
     return parser.parse_args()
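Resuming from the best checkpoint remains the default; the flag opts out. A hypothetical invocation (the script name is assumed, and --mode is defined elsewhere in the file):

    python main.py --mode train --epochs 100 --start_fresh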
@@ -364,12 +364,11 @@ def random_action():

 # --- Main Function ---
 async def main():
     args = parse_args()
-    # Use GPU if available, else fallback to CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
-    input_dim = len(timeframes) * 7  # 7 features per timeframe
     hidden_dim = 128
-    output_dim = 3  # SELL, HOLD, BUY
-    # Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
     total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
     model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
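The model is sized by channel count rather than by a flattened feature vector, which is why input_dim and output_dim are dropped. A small sketch of the assumed channel layout (the NUM_INDICATORS value below is an illustrative assumption; both constants are defined elsewhere in the file):

    NUM_TIMEFRAMES = 5   # one channel per entry in ["1m", "5m", "15m", "1h", "1d"]
    NUM_INDICATORS = 2   # assumed value, for illustration only
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    # get_state() is then expected to return an array of shape
    # (total_channels, features_per_channel), so torch.arange(state.shape[0])
    # yields one embedding id per channel.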
@@ -380,6 +379,7 @@ async def main():
             return
         base_tf = "1m"
         env = BacktestEnvironment(candles_dict, base_tf, timeframes)
         start_epoch = 0
         if not args.start_fresh:
             checkpoint = load_best_checkpoint(model)
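The hunk ends at the checkpoint load. A plausible continuation that recovers the resume epoch, sketched under the assumption that the checkpoint dict carries an 'epoch' key (neither the key nor the guard is shown in the diff):

    # Sketch: assumes load_best_checkpoint returns a dict (or None).
    if checkpoint is not None:
        start_epoch = checkpoint.get("epoch", 0) + 1
    train_on_historical_data(env, model, device, args, start_epoch=start_epoch)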
@@ -403,10 +403,10 @@ async def main():
         preview_thread.start()
         print("Starting live trading loop. (Using random actions for simulation.)")
         while True:
-            state, reward, next_state, done = env.step(random_action())
+            state, reward, next_state, done, _, _ = env.step(random_action())
             if done:
                 print("Reached end of simulated data, resetting environment.")
-                state = env.reset(clear_trade_history=False)
+                state = env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)
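random_action() itself appears only in a hunk header here. A minimal stand-in consistent with the three-action space named in the removed output_dim comment (SELL, HOLD, BUY); the index mapping is an assumption:

    import random

    def random_action():
        # 0 = SELL, 1 = HOLD, 2 = BUY (mapping assumed, not shown in the diff)
        return random.randint(0, 2)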