diff --git a/NN/training/model_manager.py b/NN/training/model_manager.py
index 29fd902..bebff61 100644
--- a/NN/training/model_manager.py
+++ b/NN/training/model_manager.py
@@ -197,6 +197,10 @@ class ModelManager:
         self.nn_models_dir = self.base_dir / "NN" / "models"
         self.legacy_models_dir = self.base_dir / "models"
 
+        # Legacy checkpoint directories (where existing checkpoints are stored)
+        self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
+        self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"
+
         # Metadata and checkpoint management
         self.metadata_file = self.checkpoints_dir / "model_metadata.json"
         self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
@@ -232,14 +236,72 @@ class ModelManager:
             directory.mkdir(parents=True, exist_ok=True)
 
     def _load_metadata(self) -> Dict[str, Any]:
-        """Load model metadata"""
+        """Load model metadata with legacy support"""
+        metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}
+
+        # First try to load from new unified metadata
         if self.metadata_file.exists():
             try:
                 with open(self.metadata_file, 'r') as f:
-                    return json.load(f)
+                    metadata = json.load(f)
+                logger.info(f"Loaded unified metadata from {self.metadata_file}")
             except Exception as e:
-                logger.error(f"Error loading metadata: {e}")
-        return {'models': {}, 'last_updated': datetime.now().isoformat()}
+                logger.error(f"Error loading unified metadata: {e}")
+
+        # Also load legacy metadata for backward compatibility
+        if self.legacy_registry_file.exists():
+            try:
+                with open(self.legacy_registry_file, 'r') as f:
+                    legacy_data = json.load(f)
+
+                # Merge legacy data into unified metadata
+                if 'models' in legacy_data:
+                    for model_name, model_info in legacy_data['models'].items():
+                        if model_name not in metadata['models']:
+                            # Convert legacy path format to absolute path
+                            if 'latest_path' in model_info:
+                                legacy_path = model_info['latest_path']
+
+                                # Handle different legacy path formats
+                                if not legacy_path.startswith('/'):
+                                    # Try multiple path resolution strategies
+                                    possible_paths = [
+                                        self.legacy_checkpoints_dir / legacy_path,         # NN/models/checkpoints/models/cnn/...
+                                        self.legacy_checkpoints_dir.parent / legacy_path,  # NN/models/models/cnn/...
+                                        self.base_dir / legacy_path,                       # /project/models/cnn/...
+                                    ]
+
+                                    resolved_path = None
+                                    for path in possible_paths:
+                                        if path.exists():
+                                            resolved_path = path
+                                            break
+
+                                    if resolved_path:
+                                        legacy_path = str(resolved_path)
+                                    else:
+                                        # If no resolved path found, try to find the file by pattern
+                                        filename = Path(legacy_path).name
+                                        for search_path in [self.legacy_checkpoints_dir]:
+                                            for file_path in search_path.rglob(filename):
+                                                legacy_path = str(file_path)
+                                                break
+
+                                metadata['models'][model_name] = {
+                                    'type': model_info.get('type', 'unknown'),
+                                    'latest_path': legacy_path,
+                                    'last_saved': model_info.get('last_saved', datetime.now().isoformat()),  # keep ISO format so fromisoformat() can parse it later
+                                    'save_count': model_info.get('save_count', 1),
+                                    'checkpoints': model_info.get('checkpoints', [])
+                                }
+                                logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")
+
+                logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
+
+            except Exception as e:
+                logger.error(f"Error loading legacy metadata: {e}")
+
+        return metadata
 
     def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
         """Load checkpoint metadata"""
@@ -407,34 +469,129 @@ class ModelManager:
             return None
 
     def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
-        """Load the best checkpoint for a model"""
+        """Load the best checkpoint for a model with legacy support"""
         try:
             # First, try the unified registry
             model_info = self.metadata['models'].get(model_name)
             if model_info and Path(model_info['latest_path']).exists():
-                # Load from unified registry
-                load_dict = torch.load(model_info['latest_path'], map_location='cpu')
-                return model_info['latest_path'], None
+                logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
+                # Create metadata from model info for compatibility
+                registry_metadata = CheckpointMetadata(
+                    checkpoint_id=f"{model_name}_registry",
+                    model_name=model_name,
+                    model_type=model_info.get('type', model_name),
+                    file_path=model_info['latest_path'],
+                    created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())),
+                    file_size_mb=0.0,  # Will be calculated if needed
+                    performance_score=0.0,  # Unknown from registry
+                    accuracy=None,
+                    loss=None,  # Orchestrator will handle this
+                    val_accuracy=None,
+                    val_loss=None
+                )
+                return model_info['latest_path'], registry_metadata
 
             # Fallback to checkpoint metadata
             checkpoints = self.checkpoint_metadata.get(model_name, [])
-            if not checkpoints:
-                logger.warning(f"No checkpoints found for {model_name}")
-                return None
+            if checkpoints:
+                # Get best checkpoint
+                best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
 
-            # Get best checkpoint
-            best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
+                if Path(best_checkpoint.file_path).exists():
+                    logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
+                    return best_checkpoint.file_path, best_checkpoint
 
-            if not Path(best_checkpoint.file_path).exists():
-                logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
-                return None
+            # Legacy fallback: Look for checkpoints in legacy directories
+            logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
+            legacy_path = self._find_legacy_checkpoint(model_name)
+            if legacy_path:
+                logger.info(f"Found legacy checkpoint: {legacy_path}")
+                # Create a basic CheckpointMetadata for the legacy checkpoint
+                legacy_metadata = CheckpointMetadata(
+                    checkpoint_id=f"legacy_{model_name}",
+                    model_name=model_name,
+                    model_type=model_name,  # Actual type unknown for legacy checkpoints; fall back to the name
+                    file_path=str(legacy_path),
+                    created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
+                    file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
+                    performance_score=0.0,  # Unknown for legacy
+                    accuracy=None,
+                    loss=None
+                )
+                return str(legacy_path), legacy_metadata
 
-            return best_checkpoint.file_path, best_checkpoint
+            logger.warning(f"No checkpoints found for {model_name} in any location")
+            return None
 
         except Exception as e:
             logger.error(f"Error loading best checkpoint for {model_name}: {e}")
             return None
 
+    def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
+        """Find checkpoint in legacy directories"""
+        if not self.legacy_checkpoints_dir.exists():
+            return None
+
+        # Define search patterns for different model types
+        # Handle both orchestrator naming and direct model naming
+        model_patterns = {
+            'dqn_agent': ['dqn_agent', 'dqn', 'agent'],
+            'enhanced_cnn': ['cnn_model', 'enhanced_cnn', 'cnn', 'optimized_short_term'],
+            'cob_rl': ['cob_rl', 'rl', 'rl_agent', 'trading_agent'],
+            'transformer': ['transformer', 'decision'],
+            'decision': ['decision', 'transformer'],
+            # Also support direct model names
+            'dqn': ['dqn_agent', 'dqn', 'agent'],
+            'cnn': ['cnn_model', 'cnn', 'optimized_short_term'],
+            'rl': ['cob_rl', 'rl', 'rl_agent']
+        }
+
+        # Get patterns for this model name or use generic patterns
+        patterns = model_patterns.get(model_name, [model_name])
+
+        # Search in legacy saved directory first
+        legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
+        if legacy_saved_dir.exists():
+            for file_path in legacy_saved_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Search in model-specific directories
+        for model_type in model_patterns.keys():
+            model_dir = self.legacy_checkpoints_dir / model_type
+            if model_dir.exists():
+                saved_dir = model_dir / "saved"
+                if saved_dir.exists():
+                    for file_path in saved_dir.rglob("*.pt"):
+                        filename = file_path.name.lower()
+                        if any(pattern in filename for pattern in patterns):
+                            return file_path
+
+        # Search in archive directory
+        archive_dir = self.legacy_checkpoints_dir / "archive"
+        if archive_dir.exists():
+            for file_path in archive_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Search in backtest directory (might contain RL or other models)
+        backtest_dir = self.legacy_checkpoints_dir / "backtest"
+        if backtest_dir.exists():
+            for file_path in backtest_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Last resort: search entire legacy directory
+        for file_path in self.legacy_checkpoints_dir.rglob("*.pt"):
+            filename = file_path.name.lower()
+            if any(pattern in filename for pattern in patterns):
+                return file_path
+
+        return None
+
     def get_storage_stats(self) -> Dict[str, Any]:
         """Get storage statistics"""
         try:
@@ -467,7 +624,7 @@ class ModelManager:
                 'models': {}
             }
 
-            # Count files in different directories as "checkpoints"
+            # Count files in new unified directories
             checkpoint_dirs = [
                 self.checkpoints_dir / "cnn",
                 self.checkpoints_dir / "dqn",
@@ -511,6 +668,34 @@ class ModelManager:
                 saved_size = sum(f.stat().st_size for f in saved_files)
                 stats['total_size_mb'] += saved_size / (1024 * 1024)
 
+            # Add legacy checkpoint statistics
+            if self.legacy_checkpoints_dir.exists():
+                legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
+                if legacy_files:
+                    legacy_size = sum(f.stat().st_size for f in legacy_files)
+                    stats['total_checkpoints'] += len(legacy_files)
+                    stats['total_size_mb'] += legacy_size / (1024 * 1024)
+
+                # Add legacy models to stats
+                legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision']
+                for model_dir_name in legacy_model_dirs:
+                    model_dir = self.legacy_checkpoints_dir / model_dir_name
+                    if model_dir.exists():
+                        model_files = list(model_dir.rglob('*.pt'))
+                        if model_files and model_dir_name not in stats['models']:
+                            stats['total_models'] += 1
+                            model_size = sum(f.stat().st_size for f in model_files)
+                            latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
+
+                            stats['models'][model_dir_name] = {
+                                'checkpoint_count': len(model_files),
+                                'total_size_mb': model_size / (1024 * 1024),
+                                'best_performance': 0.0,
+                                'best_checkpoint_id': latest_file.name,
+                                'latest_checkpoint': latest_file.name,
+                                'location': 'legacy'
+                            }
+
             return stats
 
         except Exception as e:
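
Review note on the _load_metadata change: below is a minimal standalone sketch of the path-resolution order applied to relative legacy latest_path entries. resolve_legacy_path is a hypothetical helper invented for illustration, not part of the patch.

    from pathlib import Path
    from typing import Optional

    def resolve_legacy_path(legacy_path: str,
                            legacy_checkpoints_dir: Path,
                            base_dir: Path) -> Optional[Path]:
        # Absolute paths are taken as-is (the patch only rewrites relative ones).
        if legacy_path.startswith('/'):
            return Path(legacy_path)
        # Three candidate roots are tried in order; the first existing hit wins.
        for candidate in (
            legacy_checkpoints_dir / legacy_path,
            legacy_checkpoints_dir.parent / legacy_path,
            base_dir / legacy_path,
        ):
            if candidate.exists():
                return candidate
        # Last resort, as in the patch: search the legacy tree for the bare filename.
        filename = Path(legacy_path).name
        return next(legacy_checkpoints_dir.rglob(filename), None)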
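
The load_best_checkpoint contract also changes in two ways worth calling out: the registry branch no longer calls torch.load itself, and every successful branch now returns a populated CheckpointMetadata instead of (path, None). A hedged caller-side sketch; ModelManager's constructor arguments are not shown in this diff and are elided here:

    import torch

    manager = ModelManager()  # constructor args elided; not shown in this patch
    result = manager.load_best_checkpoint('dqn_agent')
    if result is not None:
        path, meta = result  # meta is never None after this change
        state = torch.load(path, map_location='cpu')  # loading is now the caller's job
        print(f"{meta.checkpoint_id}: score={meta.performance_score} ({path})")
    else:
        print("nothing found in registry, checkpoint metadata, or legacy directories")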
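
One caveat on _find_legacy_checkpoint: matching is plain substring containment over lowercased filenames, so short patterns such as 'agent' or 'rl' can match files from another model family. A quick illustration with hypothetical filenames:

    patterns = ['dqn_agent', 'dqn', 'agent']  # the 'dqn_agent' entry from model_patterns
    for name in ('dqn_agent_best.pt', 'cnn_model_v2.pt', 'trading_agent_final.pt'):
        print(name, any(p in name.lower() for p in patterns))
    # dqn_agent_best.pt True
    # cnn_model_v2.pt False
    # trading_agent_final.pt True  (matched via 'agent', also a cob_rl pattern)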