diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index c90e297..5a10fcc 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -162,51 +162,55 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            r = float(parts[1])
-            best_files.append((r, file))
+            # parts[1] is the recorded loss
+            loss = float(parts[1])
+            best_files.append((loss, file))
         except Exception:
             continue
     return best_files
 
-def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
-        "reward": reward,
+        "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
+    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
-        min_reward, min_file = min(best_models, key=lambda x: x[0])
-        if reward > min_reward:
+        # The worst saved checkpoint will have the highest loss.
+        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
+        if loss < worst_loss:
             add_to_best = True
-            os.remove(os.path.join(best_dir, min_file))
+            os.remove(os.path.join(best_dir, worst_file))
     if add_to_best:
-        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
+        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
-            "reward": reward,
+            "loss": loss,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
-    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
+    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    best_reward, best_file = max(best_models, key=lambda x: x[0])
+    # Choose the checkpoint with the lowest loss.
+    best_loss, best_file = min(best_models, key=lambda x: x[0])
     path = os.path.join(best_dir, best_file)
-    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
+    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
     checkpoint = torch.load(path)
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
@@ -305,7 +309,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
                 break
             state = next_state
         scheduler.step()
-        print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
-        save_checkpoint(model, optimizer, epoch, total_loss)
+        epoch_loss = total_loss/len(env)
+        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
+        save_checkpoint(model, optimizer, epoch, epoch_loss)
 
 # --- Live Plotting Functions ---
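
Sanity check (not part of the patch): a standalone sketch of the loss-ranked pool behaviour the new code relies on. It copies the patched `get_best_models` parsing rule and creates hypothetical empty checkpoint files in a temporary directory, so it runs without torch or the rest of the module. It confirms that `max` over the parsed losses picks the eviction candidate and `min` picks the checkpoint to load.

```python
# Standalone sketch of the loss-ranked best-pool logic from the diff above.
# The file names mirror the checkpoint naming scheme; the files themselves
# are empty stand-ins, so no torch dependency is needed.
import os
import tempfile

def get_best_models(directory):
    # Same parsing rule as the patched function: parts[1] is the recorded loss.
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
            best_files.append((loss, file))
        except Exception:
            continue
    return best_files

with tempfile.TemporaryDirectory() as best_dir:
    # Hypothetical loss values; the timestamp is a fixed dummy.
    for epoch, loss in enumerate([0.9123, 0.4567, 0.7890]):
        name = f"best_{loss:.4f}_epoch_{epoch}_20240101_000000.pt"
        open(os.path.join(best_dir, name), "w").close()

    best_models = get_best_models(best_dir)
    worst_loss, worst_file = max(best_models, key=lambda x: x[0])  # eviction candidate
    best_loss, best_file = min(best_models, key=lambda x: x[0])    # load candidate
    assert worst_loss == 0.9123 and best_loss == 0.4567
    print(f"evict: {worst_file}  load: {best_file}")
```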