Fix model loading/saving issue

This commit is contained in:
Dobromir Popov
2025-09-02 16:05:44 +03:00
parent 1b54438082
commit 15cc694669
13 changed files with 2264 additions and 72 deletions

View File

@@ -63,6 +63,7 @@ class CheckpointManager:
self.enable_wandb = False
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
self._warned_models = set() # Track models we've warned about to reduce spam
self._load_metadata()
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
@@ -71,6 +72,7 @@ class CheckpointManager:
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with improved error handling and validation"""
try:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}"
@@ -155,7 +157,11 @@ class CheckpointManager:
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
# Only warn once per model to avoid spam
if model_name not in self._warned_models:
logger.info(f"No checkpoints found for {model_name}, starting fresh")
self._warned_models.add(model_name)
return None
except Exception as e:
@@ -327,15 +333,29 @@ class CheckpointManager:
"""Find legacy saved models based on model name patterns"""
base_dir = Path(self.base_dir)
# Additional search locations
search_dirs = [
base_dir,
Path("models/saved"),
Path("NN/models/saved"),
Path("models"),
Path("models/archive"),
Path("models/backtest")
]
# Define model name mappings and patterns for legacy files
legacy_patterns = {
'dqn_agent': [
'dqn_agent_session_policy.pt',
'dqn_agent_session_agent_state.pt',
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt'
'dqn_agent_final_policy.pt',
'trading_agent_best_pnl.pt'
],
'enhanced_cnn': [
'cnn_model_session.pt',
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
@@ -369,12 +389,16 @@ class CheckpointManager:
f'{model_name}_final_policy.pt'
])
# Search for the model files
for pattern in patterns:
candidate_path = base_dir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file: {candidate_path}")
return candidate_path
# Search for the model files in all search directories
for search_dir in search_dirs:
if not search_dir.exists():
continue
for pattern in patterns:
candidate_path = search_dir / pattern
if candidate_path.exists():
logger.info(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():