checkpoint manager
@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Dict, Any, Optional
 from pathlib import Path
 
-from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
+from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
 
 logger = logging.getLogger(__name__)
 
@@ -78,7 +78,7 @@ class TrainingIntegration:
         except Exception as e:
             logger.warning(f"Error logging to W&B: {e}")
 
-        metadata = save_checkpoint(
+        metadata = self.checkpoint_manager.save_checkpoint(
             model=cnn_model,
             model_name=model_name,
             model_type='cnn',
@@ -137,7 +137,7 @@ class TrainingIntegration:
         except Exception as e:
             logger.warning(f"Error logging to W&B: {e}")
 
-        metadata = save_checkpoint(
+        metadata = self.checkpoint_manager.save_checkpoint(
             model=rl_agent,
             model_name=model_name,
             model_type='rl',
@@ -158,7 +158,7 @@ class TrainingIntegration:
 
     def load_best_model(self, model_name: str, model_class=None):
         try:
-            result = load_best_checkpoint(model_name)
+            result = self.checkpoint_manager.load_best_checkpoint(model_name)
             if not result:
                 logger.warning(f"No checkpoint found for model: {model_name}")
                 return None
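
The commit routes both save_checkpoint call sites and the load_best_checkpoint call through self.checkpoint_manager while keeping the get_checkpoint_manager import. A minimal sketch of how these pieces could fit together, assuming self.checkpoint_manager is obtained from get_checkpoint_manager() in __init__ (not shown in these hunks); the method name save_cnn_model, the return values, and the omitted keyword arguments are illustrative only:

    # Sketch only -- __init__ and the surrounding method bodies are not part of this
    # commit; obtaining a shared manager from get_checkpoint_manager() is an
    # assumption based on the import the commit keeps.
    import logging
    from .checkpoint_manager import get_checkpoint_manager

    logger = logging.getLogger(__name__)

    class TrainingIntegration:
        def __init__(self):
            # Single manager instance used by every save/load call in this class.
            self.checkpoint_manager = get_checkpoint_manager()

        def save_cnn_model(self, cnn_model, model_name: str) -> bool:
            # Same call shape as the @@ -78 hunk; remaining keyword arguments omitted.
            metadata = self.checkpoint_manager.save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
            )
            return metadata is not None

        def load_best_model(self, model_name: str, model_class=None):
            # Mirrors the @@ -158 hunk: returns None when no checkpoint exists.
            result = self.checkpoint_manager.load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None
            return result

Routing every call through one CheckpointManager instance, rather than module-level helpers, keeps checkpoint state and configuration on a single object and makes it easier to substitute a different manager in tests.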