refactoring
This commit is contained in:
@@ -20,8 +20,8 @@ import torch.nn.functional as F
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
|
||||
from NN.training.model_manager import create_model_manager
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -778,7 +778,7 @@ class CNNModelTrainer:
|
||||
def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
|
||||
"""Save model with metadata using unified registry"""
|
||||
try:
|
||||
from utils.model_registry import save_model
|
||||
from NN.training.model_manager import save_model
|
||||
|
||||
# Prepare model data
|
||||
model_data = {
|
||||
@@ -826,7 +826,7 @@ class CNNModelTrainer:
|
||||
def load_model(self, filepath: str = None) -> Dict:
|
||||
"""Load model from unified registry or file"""
|
||||
try:
|
||||
from utils.model_registry import load_model
|
||||
from NN.training.model_manager import load_model
|
||||
|
||||
# Use unified registry if no filepath or if it's a models/ path
|
||||
if filepath is None or filepath.startswith('models/'):
|
||||
|
Reference in New Issue
Block a user