fix

commit ee10274586 (parent f220e5fc4d)
@@ -90,7 +90,7 @@ class TradingModel(nn.Module):
         channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
         stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Add an embedding for each channel.
+        # Add embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
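For context, this hunk sits in a forward pass that stacks per-channel encoder outputs, permutes to the sequence-first layout nn.TransformerEncoder expects by default, adds a learned timeframe embedding per channel, and builds a causal attention mask with torch.triu. Below is a minimal runnable sketch of that pattern; TradingModelSketch, channel_encoders, the LSTM encoders, and all dimensions are illustrative assumptions, not taken from this repo.

    import torch
    import torch.nn as nn

    class TradingModelSketch(nn.Module):
        def __init__(self, num_channels=4, input_dim=8, hidden_dim=64):
            super().__init__()
            # One encoder per input channel (an LSTM here; the real encoder may differ).
            self.channel_encoders = nn.ModuleList(
                [nn.LSTM(input_dim, hidden_dim, batch_first=True) for _ in range(num_channels)]
            )
            self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
            layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4)
            self.transformer = nn.TransformerEncoder(layer, num_layers=2)

        def forward(self, x, timeframe_ids):
            # x: [batch, channels, seq_len, input_dim]
            channel_outs = []
            for i, enc in enumerate(self.channel_encoders):
                _, (h_n, _) = enc(x[:, i])           # h_n: [1, batch, hidden]
                channel_outs.append(h_n.squeeze(0))  # [batch, hidden]
            stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
            stacked = stacked.permute(1, 0, 2)          # [channels, batch, hidden]
            # Add embedding for each channel (broadcast over the batch dimension).
            tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)  # [channels, 1, hidden]
            stacked = stacked + tf_embeds
            # True above the diagonal = each channel attends only to earlier channels.
            src_mask = torch.triu(
                torch.ones(stacked.size(0), stacked.size(0)), diagonal=1
            ).bool().to(x.device)
            return self.transformer(stacked, mask=src_mask)

    # Smoke test: out has shape [channels, batch, hidden] = [4, 2, 64].
    out = TradingModelSketch()(torch.randn(2, 4, 16, 8), torch.arange(4))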
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
-    # When set, training starts from scratch (ignoring saved checkpoints)
+    # If set, training starts from scratch (ignoring saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
     return parser.parse_args()

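Only the comment wording changes in this hunk. For reference, a self-contained sketch of the parser these four flags imply; the description string, the script name, and any other arguments (e.g. the --mode flag that main() reads later) are assumptions, not taken from the diff.

    import argparse

    def parse_args():
        parser = argparse.ArgumentParser(description='Train the trading model.')
        parser.add_argument('--epochs', type=int, default=100)
        parser.add_argument('--lr', type=float, default=3e-4)
        parser.add_argument('--threshold', type=float, default=0.005)
        # If set, training starts from scratch (ignoring saved checkpoints)
        parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
        return parser.parse_args()

    if __name__ == '__main__':
        # e.g. python train.py --epochs 50 --lr 1e-4 --start_fresh
        print(parse_args())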
@@ -395,8 +395,12 @@ async def main():
         optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
         if checkpoint is not None:
-            # Restore optimizer state for a true resume
-            optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
+            optim_state = checkpoint.get("optimizer_state_dict", None)
+            if optim_state is not None and "param_groups" in optim_state:
+                optimizer.load_state_dict(optim_state)
+                print("Loaded optimizer state from checkpoint.")
+            else:
+                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)

     elif args.mode == 'live':
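Rationale worth noting: the old line passed an empty dict to optimizer.load_state_dict() whenever the checkpoint lacked optimizer state, which raises KeyError: 'param_groups'; the new guard skips loading instead. Below is a sketch of the save/resume round trip this protects; the checkpoint file name, the toy model, and the epoch key are assumptions, not part of the diff.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 1)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

    # Save optimizer state alongside the model so AdamW's moment estimates
    # survive a restart; otherwise "resuming" resets optimization dynamics.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': 5,
    }, 'checkpoint.pt')

    # Resume with the same validity check as the patched code: a checkpoint
    # written without optimizer state falls through to the else branch
    # instead of crashing load_state_dict().
    checkpoint = torch.load('checkpoint.pt')
    optim_state = checkpoint.get('optimizer_state_dict', None)
    if optim_state is not None and 'param_groups' in optim_state:
        optimizer.load_state_dict(optim_state)
        print('Loaded optimizer state from checkpoint.')
    else:
        print('No valid optimizer state found in checkpoint; starting fresh optimizer state.')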