diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index d6253a6..2d18ffd 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -93,10 +93,10 @@ class TradingModel(nn.Module):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        # Notice that with batch_first=True, we want shape [batch, channels, hidden]
-        tf_embeds = self.timeframe_embed(timeframe_ids)
-        stacked = stacked + tf_embeds.unsqueeze(0)  # add embedding to each sample in batch
-        # The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
+        # With batch_first=True, the expected input is [batch, seq_len, hidden].
+        tf_embeds = self.timeframe_embed(timeframe_ids)  # shape: [num_channels, hidden]
+        # Expand tf_embeds to match the batch dimension.
+        stacked = stacked + tf_embeds.unsqueeze(0)
         transformer_out = self.transformer(stacked)
         attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
         aggregated = (transformer_out * attn_weights).sum(dim=1)
@@ -223,8 +223,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
             new_embed[:old_embed.shape[0]] = old_embed
             old_state["timeframe_embed.weight"] = new_embed
 
-    # (For channel_branches, if the checkpoint has fewer branches than your new model expects,
-    # missing branches will be left at their randomly initialized values.)
+    # For channel_branches, any missing keys are handled by load_state_dict with strict=False.
     model.load_state_dict(old_state, strict=False)
     return checkpoint
 
@@ -513,10 +512,18 @@ async def main():
         if checkpoint is not None:
             optim_state = checkpoint.get("optimizer_state_dict", None)
             if optim_state is not None and "param_groups" in optim_state:
-                optimizer.load_state_dict(optim_state)
-                print("Loaded optimizer state from checkpoint.")
+                try:
+                    optimizer.load_state_dict(optim_state)
+                    print("Loaded optimizer state from checkpoint.")
+                except Exception as e:
+                    print("Failed to load optimizer state due to:", e)
+                    print("Deleting all checkpoints and starting fresh.")
+                    # Delete checkpoint files.
+                    for chk_dir in [LAST_DIR, BEST_DIR]:
+                        for f in os.listdir(chk_dir):
+                            os.remove(os.path.join(chk_dir, f))
             else:
-                print("No valid optimizer state found; using fresh optimizer state.")
+                print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
 
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
     elif args.mode == 'live':
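
Note: a minimal standalone sketch (not part of the patch, using hypothetical sizes batch=4, num_channels=3, hidden=8) of the broadcast the first hunk relies on: timeframe_embed(timeframe_ids) yields [num_channels, hidden], and unsqueeze(0) makes it [1, num_channels, hidden], so the addition broadcasts over the batch.

import torch
import torch.nn as nn

batch, num_channels, hidden = 4, 3, 8                 # hypothetical sizes for illustration
timeframe_embed = nn.Embedding(num_channels, hidden)
timeframe_ids = torch.arange(num_channels)            # one id per channel/timeframe

stacked = torch.randn(batch, num_channels, hidden)    # [batch, channels, hidden]
tf_embeds = timeframe_embed(timeframe_ids)            # [num_channels, hidden]
stacked = stacked + tf_embeds.unsqueeze(0)            # [1, channels, hidden] broadcasts over batch
print(stacked.shape)                                  # torch.Size([4, 3, 8])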