reset checkpoint if model fails
parent 39ce152391, commit f32f648bf0
@@ -93,10 +93,10 @@ class TradingModel(nn.Module):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        # Notice that with batch_first=True, we want shape [batch, channels, hidden]
-        tf_embeds = self.timeframe_embed(timeframe_ids)
-        stacked = stacked + tf_embeds.unsqueeze(0)  # add embedding to each sample in batch
-        # The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
+        # With batch_first=True, the expected input is [batch, seq_len, hidden]
+        tf_embeds = self.timeframe_embed(timeframe_ids)  # shape: [num_channels, hidden]
+        # Expand tf_embeds to match the batch dimension.
+        stacked = stacked + tf_embeds.unsqueeze(0)
         transformer_out = self.transformer(stacked)
         attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
         aggregated = (transformer_out * attn_weights).sum(dim=1)
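This hunk only reworks comments, but the broadcast they describe is easy to get wrong, so here is a minimal, self-contained sketch of the pattern (sizes and names are illustrative, not the repository's): unsqueeze(0) lifts the per-channel embedding table to shape [1, channels, hidden], so adding it to a [batch, channels, hidden] tensor broadcasts one embedding row onto every sample in the batch.

import torch
import torch.nn as nn

# Illustrative sizes only.
batch, num_channels, hidden = 4, 3, 8

stacked = torch.randn(batch, num_channels, hidden)   # [batch, channels, hidden]
timeframe_embed = nn.Embedding(num_channels, hidden)
timeframe_ids = torch.arange(num_channels)           # one timeframe id per channel

tf_embeds = timeframe_embed(timeframe_ids)           # [channels, hidden]
# unsqueeze(0) -> [1, channels, hidden]; the addition broadcasts over the batch dim.
out = stacked + tf_embeds.unsqueeze(0)
assert out.shape == (batch, num_channels, hidden)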
@@ -223,8 +223,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
         new_embed[:old_embed.shape[0]] = old_embed
         old_state["timeframe_embed.weight"] = new_embed

-    # (For channel_branches, if the checkpoint has fewer branches than your new model expects,
-    # missing branches will be left at their randomly initialized values.)
+    # For channel_branches, if there are missing keys, use load_state_dict with strict=False.
     model.load_state_dict(old_state, strict=False)
     return checkpoint
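The resize-then-partial-load pattern above generalizes beyond this model. A small sketch, assuming a toy module that shares only the `timeframe_embed` key with the real one: rows that exist in the checkpoint are copied into a freshly initialized, larger table, and strict=False lets any keys still missing keep their random initialization.

import torch
import torch.nn as nn

class Toy(nn.Module):
    # Stand-in for the real model; only the embedding key matters here.
    def __init__(self, num_timeframes, hidden=8):
        super().__init__()
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden)

old_model, new_model = Toy(3), Toy(5)        # checkpoint had 3 timeframes, model now has 5
old_state = old_model.state_dict()

old_embed = old_state["timeframe_embed.weight"]                 # [3, hidden]
new_embed = new_model.timeframe_embed.weight.detach().clone()   # [5, hidden], fresh init
new_embed[:old_embed.shape[0]] = old_embed                      # keep the rows we have
old_state["timeframe_embed.weight"] = new_embed

# strict=False tolerates missing/unexpected keys; missing ones stay freshly initialized.
missing, unexpected = new_model.load_state_dict(old_state, strict=False)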
@@ -513,10 +512,18 @@ async def main():
         if checkpoint is not None:
             optim_state = checkpoint.get("optimizer_state_dict", None)
             if optim_state is not None and "param_groups" in optim_state:
-                optimizer.load_state_dict(optim_state)
-                print("Loaded optimizer state from checkpoint.")
+                try:
+                    optimizer.load_state_dict(optim_state)
+                    print("Loaded optimizer state from checkpoint.")
+                except Exception as e:
+                    print("Failed to load optimizer state due to:", e)
+                    print("Deleting all checkpoints and starting fresh.")
+                    # Delete checkpoint files.
+                    for chk_dir in [LAST_DIR, BEST_DIR]:
+                        for f in os.listdir(chk_dir):
+                            os.remove(os.path.join(chk_dir, f))
             else:
-                print("No valid optimizer state found; using fresh optimizer state.")
+                print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
         train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)

     elif args.mode == 'live':