refactoring
This commit is contained in:
@@ -14,7 +14,7 @@ from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import torch
|
||||
|
||||
from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
|
||||
from NN.training.model_manager import create_model_manager, CheckpointMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class CheckpointCleanup:
|
||||
def __init__(self):
|
||||
self.saved_models_dir = Path("NN/models/saved")
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.checkpoint_manager = create_model_manager()
|
||||
|
||||
def analyze_existing_checkpoints(self) -> Dict[str, Any]:
|
||||
logger.info("Analyzing existing checkpoint files...")
|
||||
|
||||
Reference in New Issue
Block a user