fix models loading /saving issue
This commit is contained in:
@@ -63,6 +63,7 @@ class CheckpointManager:
|
||||
self.enable_wandb = False
|
||||
|
||||
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
|
||||
self._warned_models = set() # Track models we've warned about to reduce spam
|
||||
self._load_metadata()
|
||||
|
||||
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
|
||||
@@ -71,6 +72,7 @@ class CheckpointManager:
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
"""Save a model checkpoint with improved error handling and validation"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_id = f"{model_name}_{timestamp}"
|
||||
@@ -155,7 +157,11 @@ class CheckpointManager:
|
||||
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
|
||||
return str(legacy_model_path), legacy_metadata
|
||||
|
||||
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
|
||||
# Only warn once per model to avoid spam
|
||||
if model_name not in self._warned_models:
|
||||
logger.info(f"No checkpoints found for {model_name}, starting fresh")
|
||||
self._warned_models.add(model_name)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
@@ -327,15 +333,29 @@ class CheckpointManager:
|
||||
"""Find legacy saved models based on model name patterns"""
|
||||
base_dir = Path(self.base_dir)
|
||||
|
||||
# Additional search locations
|
||||
search_dirs = [
|
||||
base_dir,
|
||||
Path("models/saved"),
|
||||
Path("NN/models/saved"),
|
||||
Path("models"),
|
||||
Path("models/archive"),
|
||||
Path("models/backtest")
|
||||
]
|
||||
|
||||
# Define model name mappings and patterns for legacy files
|
||||
legacy_patterns = {
|
||||
'dqn_agent': [
|
||||
'dqn_agent_session_policy.pt',
|
||||
'dqn_agent_session_agent_state.pt',
|
||||
'dqn_agent_best_policy.pt',
|
||||
'enhanced_dqn_best_policy.pt',
|
||||
'improved_dqn_agent_best_policy.pt',
|
||||
'dqn_agent_final_policy.pt'
|
||||
'dqn_agent_final_policy.pt',
|
||||
'trading_agent_best_pnl.pt'
|
||||
],
|
||||
'enhanced_cnn': [
|
||||
'cnn_model_session.pt',
|
||||
'cnn_model_best.pt',
|
||||
'optimized_short_term_model_best.pt',
|
||||
'optimized_short_term_model_realtime_best.pt',
|
||||
@@ -369,12 +389,16 @@ class CheckpointManager:
|
||||
f'{model_name}_final_policy.pt'
|
||||
])
|
||||
|
||||
# Search for the model files
|
||||
for pattern in patterns:
|
||||
candidate_path = base_dir / pattern
|
||||
if candidate_path.exists():
|
||||
logger.debug(f"Found legacy model file: {candidate_path}")
|
||||
return candidate_path
|
||||
# Search for the model files in all search directories
|
||||
for search_dir in search_dirs:
|
||||
if not search_dir.exists():
|
||||
continue
|
||||
|
||||
for pattern in patterns:
|
||||
candidate_path = search_dir / pattern
|
||||
if candidate_path.exists():
|
||||
logger.info(f"Found legacy model file: {candidate_path}")
|
||||
return candidate_path
|
||||
|
||||
# Also check subdirectories
|
||||
for subdir in base_dir.iterdir():
|
||||
|
Reference in New Issue
Block a user