""" Unified Model Management System for Trading Dashboard CONSOLIDATED SYSTEM - All model management functionality in one place This system provides: - Automatic cleanup of old model checkpoints - Best model tracking with performance metrics - Configurable retention policies - Startup model loading - Performance-based model selection - Robust model saving with multiple fallback strategies - Checkpoint management with W&B integration - Centralized storage using @checkpoints/ structure """ import os import json import shutil import logging import torch import glob import pickle import hashlib import random import numpy as np from pathlib import Path from datetime import datetime from dataclasses import dataclass, asdict from typing import Dict, Any, Optional, List, Tuple, Union from collections import defaultdict # W&B import (optional) try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False wandb = None logger = logging.getLogger(__name__) @dataclass class ModelMetrics: """Enhanced performance metrics for model evaluation""" accuracy: float = 0.0 profit_factor: float = 0.0 win_rate: float = 0.0 sharpe_ratio: float = 0.0 max_drawdown: float = 0.0 total_trades: int = 0 avg_trade_duration: float = 0.0 confidence_score: float = 0.0 # Additional metrics from checkpoint_manager loss: Optional[float] = None val_accuracy: Optional[float] = None val_loss: Optional[float] = None reward: Optional[float] = None pnl: Optional[float] = None epoch: Optional[int] = None training_time_hours: Optional[float] = None total_parameters: Optional[int] = None def get_composite_score(self) -> float: """Calculate composite performance score""" # Weighted composite score weights = { 'profit_factor': 0.25, 'sharpe_ratio': 0.2, 'win_rate': 0.15, 'accuracy': 0.15, 'confidence_score': 0.1, 'loss_penalty': 0.1, # New: penalize high loss 'val_penalty': 0.05 # New: penalize validation loss } # Normalize values to 0-1 range normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0 normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1 normalized_win_rate = self.win_rate normalized_accuracy = self.accuracy normalized_confidence = self.confidence_score # Loss penalty (lower loss = higher score) loss_penalty = 1.0 if self.loss is not None and self.loss > 0: loss_penalty = max(0.1, 1 / (1 + self.loss)) # Better loss = higher penalty # Validation penalty val_penalty = 1.0 if self.val_loss is not None and self.val_loss > 0: val_penalty = max(0.1, 1 / (1 + self.val_loss)) # Apply penalties for poor performance drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown score = ( weights['profit_factor'] * normalized_pf + weights['sharpe_ratio'] * normalized_sharpe + weights['win_rate'] * normalized_win_rate + weights['accuracy'] * normalized_accuracy + weights['confidence_score'] * normalized_confidence + weights['loss_penalty'] * loss_penalty + weights['val_penalty'] * val_penalty ) * drawdown_penalty return min(max(score, 0), 1) @dataclass class ModelInfo: """Model information tracking""" model_type: str # 'cnn', 'rl', 'transformer' model_name: str file_path: str creation_time: datetime last_updated: datetime file_size_mb: float metrics: ModelMetrics training_episodes: int = 0 model_version: str = "1.0" def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for JSON serialization""" data = asdict(self) data['creation_time'] = self.creation_time.isoformat() data['last_updated'] = self.last_updated.isoformat() return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo': """Create from dictionary""" data['creation_time'] = datetime.fromisoformat(data['creation_time']) data['last_updated'] = datetime.fromisoformat(data['last_updated']) data['metrics'] = ModelMetrics(**data['metrics']) return cls(**data) @dataclass class CheckpointMetadata: checkpoint_id: str model_name: str model_type: str file_path: str created_at: datetime file_size_mb: float performance_score: float accuracy: Optional[float] = None loss: Optional[float] = None val_accuracy: Optional[float] = None val_loss: Optional[float] = None reward: Optional[float] = None pnl: Optional[float] = None epoch: Optional[int] = None training_time_hours: Optional[float] = None total_parameters: Optional[int] = None wandb_run_id: Optional[str] = None wandb_artifact_name: Optional[str] = None def to_dict(self) -> Dict[str, Any]: data = asdict(self) data['created_at'] = self.created_at.isoformat() return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': data['created_at'] = datetime.fromisoformat(data['created_at']) return cls(**data) class ModelManager: """Unified model management system with @checkpoints/ structure""" def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None): self.base_dir = Path(base_dir) self.config = config or self._get_default_config() # Updated directory structure using @checkpoints/ self.checkpoints_dir = self.base_dir / "@checkpoints" self.models_dir = self.checkpoints_dir / "models" self.saved_dir = self.checkpoints_dir / "saved" self.best_models_dir = self.checkpoints_dir / "best_models" self.archive_dir = self.checkpoints_dir / "archive" # Model type directories within @checkpoints/ self.model_dirs = { 'cnn': self.checkpoints_dir / "cnn", 'dqn': self.checkpoints_dir / "dqn", 'rl': self.checkpoints_dir / "rl", 'transformer': self.checkpoints_dir / "transformer", 'hybrid': self.checkpoints_dir / "hybrid" } # Legacy directories for backward compatibility self.nn_models_dir = self.base_dir / "NN" / "models" self.legacy_models_dir = self.base_dir / "models" # Legacy checkpoint directories (where existing checkpoints are stored) self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints" self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json" # Metadata and checkpoint management self.metadata_file = self.checkpoints_dir / "model_metadata.json" self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json" # Initialize storage self._initialize_directories() self.metadata = self._load_metadata() self.checkpoint_metadata = self._load_checkpoint_metadata() logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}") def _get_default_config(self) -> Dict[str, Any]: """Get default configuration""" return { 'max_checkpoints_per_model': 5, 'cleanup_old_models': True, 'auto_archive': True, 'wandb_enabled': WANDB_AVAILABLE, 'checkpoint_retention_days': 30 } def _initialize_directories(self): """Initialize directory structure""" directories = [ self.checkpoints_dir, self.models_dir, self.saved_dir, self.best_models_dir, self.archive_dir ] + list(self.model_dirs.values()) for directory in directories: directory.mkdir(parents=True, exist_ok=True) def _load_metadata(self) -> Dict[str, Any]: """Load model metadata with legacy support""" metadata = {'models': {}, 'last_updated': datetime.now().isoformat()} # First try to load from new unified metadata if self.metadata_file.exists(): try: with open(self.metadata_file, 'r') as f: metadata = json.load(f) logger.info(f"Loaded unified metadata from {self.metadata_file}") except Exception as e: logger.error(f"Error loading unified metadata: {e}") # Also load legacy metadata for backward compatibility if self.legacy_registry_file.exists(): try: with open(self.legacy_registry_file, 'r') as f: legacy_data = json.load(f) # Merge legacy data into unified metadata if 'models' in legacy_data: for model_name, model_info in legacy_data['models'].items(): if model_name not in metadata['models']: # Convert legacy path format to absolute path if 'latest_path' in model_info: legacy_path = model_info['latest_path'] # Handle different legacy path formats if not legacy_path.startswith('/'): # Try multiple path resolution strategies possible_paths = [ self.legacy_checkpoints_dir / legacy_path, # NN/models/checkpoints/models/cnn/... self.legacy_checkpoints_dir.parent / legacy_path, # NN/models/models/cnn/... self.base_dir / legacy_path, # /project/models/cnn/... ] resolved_path = None for path in possible_paths: if path.exists(): resolved_path = path break if resolved_path: legacy_path = str(resolved_path) else: # If no resolved path found, try to find the file by pattern filename = Path(legacy_path).name for search_path in [self.legacy_checkpoints_dir]: for file_path in search_path.rglob(filename): legacy_path = str(file_path) break metadata['models'][model_name] = { 'type': model_info.get('type', 'unknown'), 'latest_path': legacy_path, 'last_saved': model_info.get('last_saved', 'legacy'), 'save_count': model_info.get('save_count', 1), 'checkpoints': model_info.get('checkpoints', []) } logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}") logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}") except Exception as e: logger.error(f"Error loading legacy metadata: {e}") return metadata def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]: """Load checkpoint metadata""" if self.checkpoint_metadata_file.exists(): try: with open(self.checkpoint_metadata_file, 'r') as f: data = json.load(f) # Convert dict values back to CheckpointMetadata objects result = {} for key, checkpoints in data.items(): result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints] return result except Exception as e: logger.error(f"Error loading checkpoint metadata: {e}") return defaultdict(list) def save_checkpoint(self, model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, force_save: bool = False) -> Optional[CheckpointMetadata]: """Save a model checkpoint with enhanced error handling and validation""" try: performance_score = self._calculate_performance_score(performance_metrics) if not force_save and not self._should_save_checkpoint(model_name, performance_score): logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved") return None # Create checkpoint directory checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints" checkpoint_dir.mkdir(parents=True, exist_ok=True) # Generate checkpoint filename timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') checkpoint_id = f"{model_name}_{timestamp}" filename = f"{checkpoint_id}.pt" filepath = checkpoint_dir / filename # Save model save_dict = { 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {}, 'model_class': model.__class__.__name__, 'checkpoint_id': checkpoint_id, 'model_name': model_name, 'model_type': model_type, 'performance_score': performance_score, 'performance_metrics': performance_metrics, 'training_metadata': training_metadata or {}, 'created_at': datetime.now().isoformat(), 'version': '2.0' } torch.save(save_dict, filepath) # Create checkpoint metadata file_size_mb = filepath.stat().st_size / (1024 * 1024) metadata = CheckpointMetadata( checkpoint_id=checkpoint_id, model_name=model_name, model_type=model_type, file_path=str(filepath), created_at=datetime.now(), file_size_mb=file_size_mb, performance_score=performance_score, accuracy=performance_metrics.get('accuracy'), loss=performance_metrics.get('loss'), val_accuracy=performance_metrics.get('val_accuracy'), val_loss=performance_metrics.get('val_loss'), reward=performance_metrics.get('reward'), pnl=performance_metrics.get('pnl'), epoch=performance_metrics.get('epoch'), training_time_hours=performance_metrics.get('training_time_hours'), total_parameters=performance_metrics.get('total_parameters') ) # Store metadata self.checkpoint_metadata[model_name].append(metadata) self._save_checkpoint_metadata() # Rotate checkpoints if needed self._rotate_checkpoints(model_name) # Upload to W&B if enabled if self.config.get('wandb_enabled'): self._upload_to_wandb(metadata) logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})") return metadata except Exception as e: logger.error(f"Error saving checkpoint for {model_name}: {e}") return None def _calculate_performance_score(self, metrics: Dict[str, float]) -> float: """Calculate performance score from metrics""" # Simple weighted score - can be enhanced weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1} score = 0.0 for metric, weight in weights.items(): if metric in metrics: score += metrics[metric] * weight return score def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool: """Determine if checkpoint should be saved""" existing_checkpoints = self.checkpoint_metadata.get(model_name, []) if not existing_checkpoints: return True # Keep if better than worst checkpoint or if we have fewer than max max_checkpoints = self.config.get('max_checkpoints_per_model', 5) if len(existing_checkpoints) < max_checkpoints: return True worst_score = min(cp.performance_score for cp in existing_checkpoints) return performance_score > worst_score def _rotate_checkpoints(self, model_name: str): """Rotate checkpoints to maintain max count""" checkpoints = self.checkpoint_metadata.get(model_name, []) max_checkpoints = self.config.get('max_checkpoints_per_model', 5) if len(checkpoints) <= max_checkpoints: return # Sort by performance score (descending) checkpoints.sort(key=lambda x: x.performance_score, reverse=True) # Remove excess checkpoints to_remove = checkpoints[max_checkpoints:] for checkpoint in to_remove: try: Path(checkpoint.file_path).unlink(missing_ok=True) logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}") except Exception as e: logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}") # Update metadata self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints] self._save_checkpoint_metadata() def _save_checkpoint_metadata(self): """Save checkpoint metadata to file""" try: data = {} for model_name, checkpoints in self.checkpoint_metadata.items(): data[model_name] = [cp.to_dict() for cp in checkpoints] with open(self.checkpoint_metadata_file, 'w') as f: json.dump(data, f, indent=2) except Exception as e: logger.error(f"Error saving checkpoint metadata: {e}") def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]: """Upload checkpoint to W&B""" if not WANDB_AVAILABLE: return None try: # This would be implemented based on your W&B workflow logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}") return None except Exception as e: logger.error(f"Error uploading to W&B: {e}") return None def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: """Load the best checkpoint for a model with legacy support""" try: # First, try the unified registry model_info = self.metadata['models'].get(model_name) if model_info and Path(model_info['latest_path']).exists(): logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}") # Create metadata from model info for compatibility registry_metadata = CheckpointMetadata( checkpoint_id=f"{model_name}_registry", model_name=model_name, model_type=model_info.get('type', model_name), file_path=model_info['latest_path'], created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())), file_size_mb=0.0, # Will be calculated if needed performance_score=0.0, # Unknown from registry accuracy=None, loss=None, # Orchestrator will handle this val_accuracy=None, val_loss=None ) return model_info['latest_path'], registry_metadata # Fallback to checkpoint metadata checkpoints = self.checkpoint_metadata.get(model_name, []) if checkpoints: # Get best checkpoint best_checkpoint = max(checkpoints, key=lambda x: x.performance_score) if Path(best_checkpoint.file_path).exists(): logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}") return best_checkpoint.file_path, best_checkpoint # Legacy fallback: Look for checkpoints in legacy directories logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}") legacy_path = self._find_legacy_checkpoint(model_name) if legacy_path: logger.info(f"Found legacy checkpoint: {legacy_path}") # Create a basic CheckpointMetadata for the legacy checkpoint legacy_metadata = CheckpointMetadata( checkpoint_id=f"legacy_{model_name}", model_name=model_name, model_type=model_name, # Will be inferred from model type file_path=str(legacy_path), created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime), file_size_mb=legacy_path.stat().st_size / (1024 * 1024), performance_score=0.0, # Unknown for legacy accuracy=None, loss=None ) return str(legacy_path), legacy_metadata logger.warning(f"No checkpoints found for {model_name} in any location") return None except Exception as e: logger.error(f"Error loading best checkpoint for {model_name}: {e}") return None def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]: """Find checkpoint in legacy directories""" if not self.legacy_checkpoints_dir.exists(): return None # Use unified model naming throughout the project # All model references use consistent short names: dqn, cnn, cob_rl, transformer, decision # This eliminates complex mapping and ensures consistency across the entire codebase patterns = [model_name] # Add minimal backward compatibility patterns if model_name == 'dqn': patterns.extend(['dqn_agent', 'agent']) elif model_name == 'cnn': patterns.extend(['cnn_model', 'enhanced_cnn']) elif model_name == 'cob_rl': patterns.extend(['rl', 'rl_agent', 'trading_agent']) # Search in legacy saved directory first legacy_saved_dir = self.legacy_checkpoints_dir / "saved" if legacy_saved_dir.exists(): for file_path in legacy_saved_dir.rglob("*.pt"): filename = file_path.name.lower() if any(pattern in filename for pattern in patterns): return file_path # Search in model-specific directories for model_type in ['cnn', 'dqn', 'rl', 'transformer', 'decision']: model_dir = self.legacy_checkpoints_dir / model_type if model_dir.exists(): saved_dir = model_dir / "saved" if saved_dir.exists(): for file_path in saved_dir.rglob("*.pt"): filename = file_path.name.lower() if any(pattern in filename for pattern in patterns): return file_path # Search in archive directory archive_dir = self.legacy_checkpoints_dir / "archive" if archive_dir.exists(): for file_path in archive_dir.rglob("*.pt"): filename = file_path.name.lower() if any(pattern in filename for pattern in patterns): return file_path # Search in backtest directory (might contain RL or other models) backtest_dir = self.legacy_checkpoints_dir / "backtest" if backtest_dir.exists(): for file_path in backtest_dir.rglob("*.pt"): filename = file_path.name.lower() if any(pattern in filename for pattern in patterns): return file_path # Last resort: search entire legacy directory for file_path in self.legacy_checkpoints_dir.rglob("*.pt"): filename = file_path.name.lower() if any(pattern in filename for pattern in patterns): return file_path return None def get_storage_stats(self) -> Dict[str, Any]: """Get storage statistics""" try: total_size = 0 file_count = 0 for directory in [self.checkpoints_dir, self.models_dir, self.saved_dir]: if directory.exists(): for file_path in directory.rglob('*'): if file_path.is_file(): total_size += file_path.stat().st_size file_count += 1 return { 'total_size_mb': total_size / (1024 * 1024), 'file_count': file_count, 'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0 } except Exception as e: logger.error(f"Error getting storage stats: {e}") return {'error': str(e)} def get_checkpoint_stats(self) -> Dict[str, Any]: """Get statistics about managed checkpoints (compatible with old checkpoint_manager interface)""" try: stats = { 'total_models': 0, 'total_checkpoints': 0, 'total_size_mb': 0.0, 'models': {} } # Count files in new unified directories checkpoint_dirs = [ self.checkpoints_dir / "cnn", self.checkpoints_dir / "dqn", self.checkpoints_dir / "rl", self.checkpoints_dir / "transformer", self.checkpoints_dir / "hybrid" ] total_size = 0 total_files = 0 for checkpoint_dir in checkpoint_dirs: if checkpoint_dir.exists(): model_files = list(checkpoint_dir.rglob('*.pt')) if model_files: model_name = checkpoint_dir.name stats['total_models'] += 1 model_size = sum(f.stat().st_size for f in model_files) stats['total_checkpoints'] += len(model_files) stats['total_size_mb'] += model_size / (1024 * 1024) total_size += model_size total_files += len(model_files) # Get the most recent file as "latest" latest_file = max(model_files, key=lambda f: f.stat().st_mtime) stats['models'][model_name] = { 'checkpoint_count': len(model_files), 'total_size_mb': model_size / (1024 * 1024), 'best_performance': 0.0, # Not tracked in unified system 'best_checkpoint_id': latest_file.name, 'latest_checkpoint': latest_file.name } # Also check saved models directory if self.saved_dir.exists(): saved_files = list(self.saved_dir.rglob('*.pt')) if saved_files: stats['total_checkpoints'] += len(saved_files) saved_size = sum(f.stat().st_size for f in saved_files) stats['total_size_mb'] += saved_size / (1024 * 1024) # Add legacy checkpoint statistics if self.legacy_checkpoints_dir.exists(): legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt')) if legacy_files: legacy_size = sum(f.stat().st_size for f in legacy_files) stats['total_checkpoints'] += len(legacy_files) stats['total_size_mb'] += legacy_size / (1024 * 1024) # Add legacy models to stats legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision'] for model_dir_name in legacy_model_dirs: model_dir = self.legacy_checkpoints_dir / model_dir_name if model_dir.exists(): model_files = list(model_dir.rglob('*.pt')) if model_files and model_dir_name not in stats['models']: stats['total_models'] += 1 model_size = sum(f.stat().st_size for f in model_files) latest_file = max(model_files, key=lambda f: f.stat().st_mtime) stats['models'][model_dir_name] = { 'checkpoint_count': len(model_files), 'total_size_mb': model_size / (1024 * 1024), 'best_performance': 0.0, 'best_checkpoint_id': latest_file.name, 'latest_checkpoint': latest_file.name, 'location': 'legacy' } return stats except Exception as e: logger.error(f"Error getting checkpoint stats: {e}") return { 'total_models': 0, 'total_checkpoints': 0, 'total_size_mb': 0.0, 'models': {}, 'error': str(e) } def get_model_leaderboard(self) -> List[Dict[str, Any]]: """Get model performance leaderboard""" try: leaderboard = [] for model_name, model_info in self.metadata['models'].items(): if 'metrics' in model_info: metrics = ModelMetrics(**model_info['metrics']) leaderboard.append({ 'model_name': model_name, 'model_type': model_info.get('model_type', 'unknown'), 'composite_score': metrics.get_composite_score(), 'accuracy': metrics.accuracy, 'profit_factor': metrics.profit_factor, 'win_rate': metrics.win_rate, 'last_updated': model_info.get('last_saved', 'unknown') }) # Sort by composite score leaderboard.sort(key=lambda x: x['composite_score'], reverse=True) return leaderboard except Exception as e: logger.error(f"Error getting leaderboard: {e}") return [] # ===== LEGACY COMPATIBILITY FUNCTIONS ===== def create_model_manager() -> ModelManager: """Create and return a ModelManager instance""" return ModelManager() def save_model(model: Any, model_name: str, model_type: str = 'cnn', metadata: Optional[Dict[str, Any]] = None) -> bool: """Legacy compatibility function to save a model""" manager = create_model_manager() return manager.save_model(model, model_name, model_type, metadata) def load_model(model_name: str, model_type: str = 'cnn', model_class: Optional[Any] = None) -> Optional[Any]: """Legacy compatibility function to load a model""" manager = create_model_manager() return manager.load_model(model_name, model_type, model_class) def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, force_save: bool = False) -> Optional[CheckpointMetadata]: """Legacy compatibility function to save a checkpoint""" manager = create_model_manager() return manager.save_checkpoint(model, model_name, model_type, performance_metrics, training_metadata, force_save) def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: """Legacy compatibility function to load the best checkpoint""" manager = create_model_manager() return manager.load_best_checkpoint(model_name) # ===== EXAMPLE USAGE ===== if __name__ == "__main__": # Example usage of the unified model manager manager = create_model_manager() print(f"ModelManager initialized at: {manager.checkpoints_dir}") # Get storage stats stats = manager.get_storage_stats() print(f"Storage stats: {stats}") # Get leaderboard leaderboard = manager.get_model_leaderboard() print(f"Models in leaderboard: {len(leaderboard)}")