From ee102745866e6c8bd160a86ef9d859ec0f9b1a10 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 4 Feb 2025 21:25:38 +0200
Subject: [PATCH] fix

---
 crypto/brian/index-deep-new.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index 855b6b0..c90e297 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -90,7 +90,7 @@ class TradingModel(nn.Module):
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
         stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Add an embedding for each channel.
+        # Add embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
-    # When set, training starts from scratch (ignoring saved checkpoints)
+    # If set, training starts from scratch (ignoring saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
     return parser.parse_args()
 
@@ -395,8 +395,12 @@ async def main():
         optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
         if checkpoint is not None:
-            # Restore optimizer state for a true resume
-            optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
+            optim_state = checkpoint.get("optimizer_state_dict", None)
+            if optim_state is not None and "param_groups" in optim_state:
+                optimizer.load_state_dict(optim_state)
+                print("Loaded optimizer state from checkpoint.")
+            else:
+                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
 
     elif args.mode == 'live':
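
Note: the new guard in main() assumes checkpoints were written with an
"optimizer_state_dict" entry alongside the model weights. Below is a minimal,
self-contained sketch of that save/resume round trip; the helper names
save_checkpoint and resume_optimizer are hypothetical and not part of
index-deep-new.py, and nn.Linear stands in for the real TradingModel.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    def save_checkpoint(path, model, optimizer, epoch):
        # Persist model and optimizer state together so a later run can
        # resume training exactly where this one stopped.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, path)

    def resume_optimizer(optimizer, checkpoint):
        # Mirror the patch's guard: restore optimizer state only when the
        # checkpoint carries a well-formed state dict.
        optim_state = checkpoint.get("optimizer_state_dict", None)
        if optim_state is not None and "param_groups" in optim_state:
            optimizer.load_state_dict(optim_state)
            print("Loaded optimizer state from checkpoint.")
        else:
            print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")

    if __name__ == "__main__":
        model = nn.Linear(4, 1)  # hypothetical stand-in for TradingModel
        optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
        save_checkpoint("checkpoint.pt", model, optimizer, epoch=10)
        checkpoint = torch.load("checkpoint.pt")
        resume_optimizer(optimizer, checkpoint)

Because optimizer.state_dict() always contains a "param_groups" key, the guard
passes for checkpoints written this way, and quietly falls back to a fresh
optimizer state for older checkpoints that omitted the entry.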