reset checkpoint if model fails

Dobromir Popov 2025-02-04 22:23:15 +02:00
parent 39ce152391
commit f32f648bf0


@@ -93,10 +93,10 @@ class TradingModel(nn.Module):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        # Notice that with batch_first=True, we want shape [batch, channels, hidden]
-        tf_embeds = self.timeframe_embed(timeframe_ids)
-        stacked = stacked + tf_embeds.unsqueeze(0)  # add embedding to each sample in batch
-        # The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
+        # With batch_first=True, the expected input is [batch, seq_len, hidden]
+        tf_embeds = self.timeframe_embed(timeframe_ids)  # shape: [num_channels, hidden]
+        # Expand tf_embeds to match the batch dimension.
+        stacked = stacked + tf_embeds.unsqueeze(0)
         transformer_out = self.transformer(stacked)
         attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
         aggregated = (transformer_out * attn_weights).sum(dim=1)
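For reference, the broadcast described by the new comment can be checked in isolation: unsqueeze(0) turns the per-channel timeframe embedding of shape [num_channels, hidden] into [1, num_channels, hidden], which then broadcasts over the batch dimension of stacked. A minimal sketch, with illustrative sizes rather than the model's real configuration:

import torch
import torch.nn as nn

batch, num_channels, hidden = 4, 5, 32               # illustrative sizes only
timeframe_embed = nn.Embedding(num_channels, hidden)

stacked = torch.randn(batch, num_channels, hidden)   # [batch, channels, hidden]
timeframe_ids = torch.arange(num_channels)           # one id per channel

tf_embeds = timeframe_embed(timeframe_ids)           # [num_channels, hidden]
out = stacked + tf_embeds.unsqueeze(0)               # [1, channels, hidden] broadcasts over batch
print(out.shape)                                     # torch.Size([4, 5, 32])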
@@ -223,8 +223,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
         new_embed[:old_embed.shape[0]] = old_embed
         old_state["timeframe_embed.weight"] = new_embed
-    # (For channel_branches, if the checkpoint has fewer branches than your new model expects,
-    # missing branches will be left at their randomly initialized values.)
+    # For channel_branches, if there are missing keys, load_state_dict with strict=False.
     model.load_state_dict(old_state, strict=False)
     return checkpoint
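The shortened comment leans on how load_state_dict behaves with strict=False: keys missing from the checkpoint (for example, channel branches added after the checkpoint was saved) simply keep their fresh initialization, while the resized timeframe embedding is copied in explicitly. A rough sketch of that idea, using the same key name as in the hunk; TinyModel and load_partial_state are hypothetical stand-ins, not the repo's classes:

import torch.nn as nn

class TinyModel(nn.Module):
    # Stand-in with the same parameter name as in the hunk; the real TradingModel is larger.
    def __init__(self, num_timeframes=5, hidden=32):
        super().__init__()
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden)

def load_partial_state(model, old_state):
    # Hypothetical helper, not the repo's load_best_checkpoint.
    old_embed = old_state.get("timeframe_embed.weight")
    new_embed = model.timeframe_embed.weight.data
    if old_embed is not None and old_embed.shape[0] < new_embed.shape[0]:
        grown = new_embed.clone()
        grown[:old_embed.shape[0]] = old_embed        # keep the rows the checkpoint does have
        old_state["timeframe_embed.weight"] = grown
    # strict=False: any keys missing from old_state stay at their fresh initialization.
    model.load_state_dict(old_state, strict=False)

old_model = TinyModel(num_timeframes=3)               # checkpoint from a smaller model
new_model = TinyModel(num_timeframes=5)
load_partial_state(new_model, old_model.state_dict())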
@@ -513,10 +512,18 @@ async def main():
         if checkpoint is not None:
             optim_state = checkpoint.get("optimizer_state_dict", None)
             if optim_state is not None and "param_groups" in optim_state:
-                optimizer.load_state_dict(optim_state)
-                print("Loaded optimizer state from checkpoint.")
+                try:
+                    optimizer.load_state_dict(optim_state)
+                    print("Loaded optimizer state from checkpoint.")
+                except Exception as e:
+                    print("Failed to load optimizer state due to:", e)
+                    print("Deleting all checkpoints and starting fresh.")
+                    # Delete checkpoint files.
+                    for chk_dir in [LAST_DIR, BEST_DIR]:
+                        for f in os.listdir(chk_dir):
+                            os.remove(os.path.join(chk_dir, f))
             else:
-                print("No valid optimizer state found; using fresh optimizer state.")
+                print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
     elif args.mode == 'live':
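The new try/except encodes the commit's intent: if the saved optimizer state no longer matches the model (typically after an architecture change), the stale checkpoints are wiped so the next run starts fresh instead of repeatedly failing. A minimal sketch of that pattern as a standalone helper, assuming the checkpoint directories contain only checkpoint files; load_optimizer_or_reset is a hypothetical name, not the repo's function:

import os

def load_optimizer_or_reset(optimizer, checkpoint, checkpoint_dirs):
    optim_state = checkpoint.get("optimizer_state_dict") if checkpoint else None
    if optim_state is None or "param_groups" not in optim_state:
        print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
        return False
    try:
        optimizer.load_state_dict(optim_state)
        print("Loaded optimizer state from checkpoint.")
        return True
    except Exception as e:
        # A mismatch here usually means the architecture changed, so the saved
        # checkpoints can no longer be trusted; delete them and start over.
        print("Failed to load optimizer state due to:", e)
        print("Deleting all checkpoints and starting fresh.")
        for chk_dir in checkpoint_dirs:
            for f in os.listdir(chk_dir):
                os.remove(os.path.join(chk_dir, f))
        return False

Emptying both LAST_DIR and BEST_DIR is deliberate: keeping either one would leave a checkpoint that still cannot be loaded on the next run.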