training resume works

parent ee10274586
commit c9f7367bcf
@@ -162,51 +162,55 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            r = float(parts[1])
-            best_files.append((r, file))
+            # parts[1] is the recorded loss
+            loss = float(parts[1])
+            best_files.append((loss, file))
         except Exception:
             continue
     return best_files
 
-def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
-        "reward": reward,
+        "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
+    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
-        min_reward, min_file = min(best_models, key=lambda x: x[0])
-        if reward > min_reward:
+        # The worst saved checkpoint will have the highest loss.
+        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
+        if loss < worst_loss:
             add_to_best = True
-            os.remove(os.path.join(best_dir, min_file))
+            os.remove(os.path.join(best_dir, worst_file))
     if add_to_best:
-        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
+        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
-            "reward": reward,
+            "loss": loss,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
-    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
+    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    best_reward, best_file = max(best_models, key=lambda x: x[0])
+    # Choose the checkpoint with the lowest loss.
+    best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
-    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
+    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
     checkpoint = torch.load(path)
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
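A minimal sketch of how a training script could resume from these checkpoints, assuming the caller restores the optimizer state and picks the next epoch itself (resume_from_best is a hypothetical helper for illustration, not part of this commit):

def resume_from_best(model, optimizer, best_dir=BEST_DIR):
    # Hypothetical helper, illustration only.
    checkpoint = load_best_checkpoint(model, best_dir=best_dir)
    if checkpoint is None:
        return 0  # no best checkpoint yet: start from epoch 0
    # load_best_checkpoint has already restored the model weights; the
    # optimizer state and the epoch to resume from come from the same dict.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1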
@@ -305,7 +309,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
                 break
             state = next_state
         scheduler.step()
-        print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
+        epoch_loss = total_loss/len(env)
+        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
 
 # --- Live Plotting Functions ---
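For reference, get_best_models recovers the loss by splitting the filename on underscores, so it relies on the best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt naming written by save_checkpoint above. A quick illustration with a made-up filename:

# Illustration only: parsing a best-checkpoint filename as get_best_models does.
name = "best_0.0231_epoch_12_20240101_120000.pt"
parts = name.split("_")   # ["best", "0.0231", "epoch", "12", "20240101", "120000.pt"]
loss = float(parts[1])    # 0.0231; files that do not match are skipped by the except block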