training resume works

Dobromir Popov 2025-02-04 21:29:32 +02:00
parent ee10274586
commit c9f7367bcf

@@ -162,51 +162,55 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            r = float(parts[1])
-            best_files.append((r, file))
+            # parts[1] is the recorded loss
+            loss = float(parts[1])
+            best_files.append((loss, file))
         except Exception:
             continue
     return best_files
 
-def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
-        "reward": reward,
+        "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
+    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
-        min_reward, min_file = min(best_models, key=lambda x: x[0])
-        if reward > min_reward:
+        # The worst saved checkpoint will have the highest loss.
+        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
+        if loss < worst_loss:
             add_to_best = True
-            os.remove(os.path.join(best_dir, min_file))
+            os.remove(os.path.join(best_dir, worst_file))
     if add_to_best:
-        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
+        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
-            "reward": reward,
+            "loss": loss,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
-    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
+    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    best_reward, best_file = max(best_models, key=lambda x: x[0])
+    # Choose the checkpoint with the lowest loss.
+    best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
-    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
+    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
     checkpoint = torch.load(path)
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
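
For context, a minimal sketch of how these helpers could be used to resume a run: load_best_checkpoint already restores the model weights and returns the checkpoint dict, so a caller only has to restore the optimizer state and pick the next epoch. The resume_training wrapper below is hypothetical and not part of this commit; it only assumes the dict layout written by save_checkpoint above.

# Hypothetical helper (not in this commit); assumes the checkpoint dict
# layout written by save_checkpoint: epoch, loss, model/optimizer state.
def resume_training(model, optimizer, best_dir=BEST_DIR):
    checkpoint = load_best_checkpoint(model, best_dir)  # also loads model weights
    if checkpoint is None:
        return 0  # nothing saved yet: start from epoch 0
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch
    print(f"Resuming from epoch {start_epoch} (saved loss {checkpoint['loss']:.4f})")
    return start_epoch
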
@@ -305,7 +309,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
                 break
             state = next_state
         scheduler.step()
-        print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
+        epoch_loss = total_loss/len(env)
+        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
 
 # --- Live Plotting Functions ---
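
A note on the filename convention the ranking depends on: best checkpoints are named best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt, so file.split("_")[1] is the recorded loss, and any filename whose second field does not parse as a number is skipped by the try/except in get_best_models. A small illustration with made-up filenames:

# Illustration only (made-up filenames): what get_best_models parses.
for name in ["best_0.0132_epoch_7_20250204_212932.pt",
             "model_last_epoch_7_20250204_212932.pt"]:
    parts = name.split("_")
    try:
        print(name, "-> loss", float(parts[1]))       # best_* files parse
    except ValueError:
        print(name, "-> skipped, no parseable loss")  # other files are ignored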