reset checkpoint if model fails
This commit is contained in:
parent
39ce152391
commit
f32f648bf0
@ -93,10 +93,10 @@ class TradingModel(nn.Module):
|
|||||||
channel_out = self.channel_branches[i](x[:, i, :])
|
channel_out = self.channel_branches[i](x[:, i, :])
|
||||||
channel_outs.append(channel_out)
|
channel_outs.append(channel_out)
|
||||||
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
|
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
|
||||||
# Notice that with batch_first=True, we want shape [batch, channels, hidden]
|
# With batch_first=True, the expected input is [batch, seq_len, hidden]
|
||||||
tf_embeds = self.timeframe_embed(timeframe_ids)
|
tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden]
|
||||||
stacked = stacked + tf_embeds.unsqueeze(0) # add embedding to each sample in batch
|
# Expand tf_embeds to match the batch dimension.
|
||||||
# The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
|
stacked = stacked + tf_embeds.unsqueeze(0)
|
||||||
transformer_out = self.transformer(stacked)
|
transformer_out = self.transformer(stacked)
|
||||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
|
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
|
||||||
aggregated = (transformer_out * attn_weights).sum(dim=1)
|
aggregated = (transformer_out * attn_weights).sum(dim=1)
|
||||||
@ -223,8 +223,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
|||||||
new_embed[:old_embed.shape[0]] = old_embed
|
new_embed[:old_embed.shape[0]] = old_embed
|
||||||
old_state["timeframe_embed.weight"] = new_embed
|
old_state["timeframe_embed.weight"] = new_embed
|
||||||
|
|
||||||
# (For channel_branches, if the checkpoint has fewer branches than your new model expects,
|
# For channel_branches, if there are missing keys, load_state_dict with strict=False.
|
||||||
# missing branches will be left at their randomly initialized values.)
|
|
||||||
model.load_state_dict(old_state, strict=False)
|
model.load_state_dict(old_state, strict=False)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
@ -513,10 +512,18 @@ async def main():
|
|||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
optim_state = checkpoint.get("optimizer_state_dict", None)
|
optim_state = checkpoint.get("optimizer_state_dict", None)
|
||||||
if optim_state is not None and "param_groups" in optim_state:
|
if optim_state is not None and "param_groups" in optim_state:
|
||||||
optimizer.load_state_dict(optim_state)
|
try:
|
||||||
print("Loaded optimizer state from checkpoint.")
|
optimizer.load_state_dict(optim_state)
|
||||||
|
print("Loaded optimizer state from checkpoint.")
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to load optimizer state due to:", e)
|
||||||
|
print("Deleting all checkpoints and starting fresh.")
|
||||||
|
# Delete checkpoint files.
|
||||||
|
for chk_dir in [LAST_DIR, BEST_DIR]:
|
||||||
|
for f in os.listdir(chk_dir):
|
||||||
|
os.remove(os.path.join(chk_dir, f))
|
||||||
else:
|
else:
|
||||||
print("No valid optimizer state found; using fresh optimizer state.")
|
print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
|
||||||
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
||||||
|
|
||||||
elif args.mode == 'live':
|
elif args.mode == 'live':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user