refactoring
This commit is contained in:
@@ -15,8 +15,8 @@ import time
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
|
||||
from NN.training.model_manager import create_model_manager
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1333,7 +1333,7 @@ class DQNAgent:
|
||||
def save(self, path: str = None):
|
||||
"""Save model and agent state using unified registry"""
|
||||
try:
|
||||
from utils.model_registry import save_model
|
||||
from NN.training.model_manager import save_model
|
||||
|
||||
# Use unified registry if no path or if it's a models/ path
|
||||
if path is None or path.startswith('models/'):
|
||||
@@ -1393,7 +1393,7 @@ class DQNAgent:
|
||||
def load(self, path: str = None):
|
||||
"""Load model and agent state from unified registry or file"""
|
||||
try:
|
||||
from utils.model_registry import load_model
|
||||
from NN.training.model_manager import load_model
|
||||
|
||||
# Use unified registry if no path or if it's a models/ path
|
||||
if path is None or path.startswith('models/'):
|
||||
|
||||
Reference in New Issue
Block a user