commit ee10274586 (parent f220e5fc4d)
Author: Dobromir Popov
Date:   2025-02-04 21:25:38 +02:00


@@ -90,7 +90,7 @@ class TradingModel(nn.Module):
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
         stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Add an embedding for each channel.
+        # Add embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
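For context, a minimal sketch of the pattern this hunk touches: per-channel outputs are stacked, a learned timeframe embedding is added, and a causal mask is passed to a transformer encoder. The class name ChannelEncoderSketch and the sizes (hidden=128, 5 timeframes, 4 heads, 2 layers) are illustrative assumptions, not taken from the repository.

import torch
import torch.nn as nn

class ChannelEncoderSketch(nn.Module):
    """Sketch: stack per-channel features, add a timeframe embedding,
    then run a masked TransformerEncoder, as in the hunk above."""

    def __init__(self, hidden=128, num_timeframes=5, nhead=4, num_layers=2):
        super().__init__()
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden)
        layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=nhead)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, channel_outs, timeframe_ids):
        # channel_outs: list of [batch, hidden] tensors, one per timeframe/channel
        stacked = torch.stack(channel_outs, dim=1)   # [batch, channels, hidden]
        stacked = stacked.permute(1, 0, 2)           # [channels, batch, hidden]
        # Add embedding for each channel, broadcast over the batch dimension.
        tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)  # [channels, 1, hidden]
        stacked = stacked + tf_embeds
        # Boolean causal mask: True above the diagonal blocks attention to later channels.
        src_mask = torch.triu(
            torch.ones(stacked.size(0), stacked.size(0)), diagonal=1
        ).bool().to(stacked.device)
        return self.encoder(stacked, mask=src_mask)   # [channels, batch, hidden]

# Shape check:
# enc = ChannelEncoderSketch()
# out = enc([torch.randn(8, 128) for _ in range(5)], torch.arange(5))  # -> [5, 8, 128]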
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
-    # When set, training starts from scratch (ignoring saved checkpoints)
+    # If set, training starts from scratch (ignoring saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
     return parser.parse_args()
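Presumably main() checks --start_fresh before loading any saved checkpoint; a hedged sketch of that gating, where CHECKPOINT_PATH and maybe_load_checkpoint are illustrative names not taken from this commit:

import os
import torch

CHECKPOINT_PATH = "checkpoints/model_latest.pt"  # illustrative path, not from the repo

def maybe_load_checkpoint(args, device):
    """Return the checkpoint dict, or None when --start_fresh is set or no file exists."""
    if args.start_fresh or not os.path.exists(CHECKPOINT_PATH):
        return None
    return torch.load(CHECKPOINT_PATH, map_location=device)

# Invocation (script name illustrative):
#   python main.py --epochs 100 --lr 3e-4 --start_fresh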
@@ -395,8 +395,12 @@ async def main():
         optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
         if checkpoint is not None:
-            # Restore optimizer state for a true resume
-            optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
+            optim_state = checkpoint.get("optimizer_state_dict", None)
+            if optim_state is not None and "param_groups" in optim_state:
+                optimizer.load_state_dict(optim_state)
+                print("Loaded optimizer state from checkpoint.")
+            else:
+                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
     elif args.mode == 'live':
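The new guard only pays off if the save side actually writes an "optimizer_state_dict" entry. Below is a self-contained sketch of a matching save/restore round trip; the checkpoint file name and the tiny stand-in model are chosen for illustration only and are not the repository's code.

import torch
import torch.nn as nn
import torch.optim as optim

# Save side: persist optimizer state next to the model weights so a true resume is possible.
model = nn.Linear(4, 2)  # stand-in for TradingModel
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint.pt",
)

# Restore side: the same guard the commit introduces.
checkpoint = torch.load("checkpoint.pt")
optim_state = checkpoint.get("optimizer_state_dict", None)
if optim_state is not None and "param_groups" in optim_state:
    optimizer.load_state_dict(optim_state)
    print("Loaded optimizer state from checkpoint.")
else:
    print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")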