training resume works

Dobromir Popov 2025-02-04 21:29:32 +02:00
parent ee10274586
commit c9f7367bcf


@@ -162,51 +162,55 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            r = float(parts[1])
-            best_files.append((r, file))
+            # parts[1] is the recorded loss
+            loss = float(parts[1])
+            best_files.append((loss, file))
         except Exception:
             continue
     return best_files
 
-def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
-        "reward": reward,
+        "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
+    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
-        min_reward, min_file = min(best_models, key=lambda x: x[0])
-        if reward > min_reward:
+        # The worst saved checkpoint will have the highest loss.
+        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
+        if loss < worst_loss:
             add_to_best = True
-            os.remove(os.path.join(best_dir, min_file))
+            os.remove(os.path.join(best_dir, worst_file))
     if add_to_best:
-        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
+        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
-            "reward": reward,
+            "loss": loss,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
-    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
+    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    best_reward, best_file = max(best_models, key=lambda x: x[0])
+    # Choose the checkpoint with the lowest loss.
+    best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
-    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
+    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
     checkpoint = torch.load(path)
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
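
The commit message says training resume now works. As context only, here is a minimal sketch, not part of this diff, of how a caller could restart from the checkpoints written above; resume_from_best is a hypothetical helper name, and it assumes the keys saved by save_checkpoint in this hunk.

def resume_from_best(model, optimizer, best_dir=BEST_DIR):
    # Restore the lowest-loss checkpoint saved by save_checkpoint above (hypothetical helper).
    checkpoint = load_best_checkpoint(model, best_dir=best_dir)
    if checkpoint is None:
        return 0  # nothing saved yet: start from epoch 0
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch
    print(f"Resuming training from epoch {start_epoch} (loss {checkpoint['loss']:.4f})")
    return start_epoch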
@@ -305,7 +309,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
                 break
             state = next_state
         scheduler.step()
-        print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
+        epoch_loss = total_loss/len(env)
+        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
 
 # --- Live Plotting Functions ---
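
Both hunks call maintain_checkpoint_directory, which falls outside the shown context. For reference only, one plausible implementation, an assumption rather than the repository's actual code, prunes a directory down to its newest max_files checkpoints.

import os

def maintain_checkpoint_directory(directory, max_files=10):
    # Assumed behavior: keep the max_files most recently modified files, delete the rest.
    paths = [os.path.join(directory, name) for name in os.listdir(directory)]
    paths = [p for p in paths if os.path.isfile(p)]
    paths.sort(key=os.path.getmtime, reverse=True)
    for stale in paths[max_files:]:
        os.remove(stale)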