training resume works
parent ee10274586
commit c9f7367bcf
@@ -162,51 +162,55 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            r = float(parts[1])
-            best_files.append((r, file))
+            # parts[1] is the recorded loss
+            loss = float(parts[1])
+            best_files.append((loss, file))
         except Exception:
             continue
     return best_files

-def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
-        "reward": reward,
+        "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
+    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
-        min_reward, min_file = min(best_models, key=lambda x: x[0])
-        if reward > min_reward:
+        # The worst saved checkpoint will have the highest loss.
+        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
+        if loss < worst_loss:
             add_to_best = True
-            os.remove(os.path.join(best_dir, min_file))
+            os.remove(os.path.join(best_dir, worst_file))
     if add_to_best:
-        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
+        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
-            "reward": reward,
+            "loss": loss,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
     maintain_checkpoint_directory(best_dir, max_files=10)
-    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
+    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")

 def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    best_reward, best_file = max(best_models, key=lambda x: x[0])
+    # Choose the checkpoint with the lowest loss.
+    best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
-    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
+    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
     checkpoint = torch.load(path)
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
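A quick illustration (not part of the commit) of why parts[1] above recovers the recorded loss: the best-pool filenames are written as f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt", so splitting on "_" puts the loss at index 1, and any filename that does not fit this pattern fails the float() conversion and is skipped by the except/continue. The values below are invented for the example.

    # Example only; the filename and its numbers are made up.
    name = "best_0.1234_epoch_7_20250101_120000.pt"
    parts = name.split("_")
    # parts == ["best", "0.1234", "epoch", "7", "20250101", "120000.pt"]
    loss = float(parts[1])  # 0.1234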
@@ -305,7 +309,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
                 break
             state = next_state
         scheduler.step()
-        print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
+        epoch_loss = total_loss/len(env)
+        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)

 # --- Live Plotting Functions ---
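The commit message refers to resuming training; one way these helpers could be wired together for that is sketched below. This is only a sketch: resume_from_best is a hypothetical name that does not appear in the commit, and it assumes the module-level BEST_DIR plus the load_best_checkpoint and save_checkpoint shown above, using only the "epoch" and "optimizer_state_dict" keys those functions write.

    # Hypothetical sketch (not from the commit): restore the best weights, the
    # optimizer state, and the epoch to continue from.
    def resume_from_best(model, optimizer, best_dir=BEST_DIR):
        checkpoint = load_best_checkpoint(model, best_dir=best_dir)
        if checkpoint is None:
            return 0  # nothing saved yet; start from the first epoch
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"] + 1  # resume with the epoch after the saved one

The returned value could then be passed as the start_epoch argument of train_on_historical_data.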