""" Unified Model Management System for Trading Dashboard CONSOLIDATED SYSTEM - All model management functionality in one place This system provides: - Automatic cleanup of old model checkpoints - Best model tracking with performance metrics - Configurable retention policies - Startup model loading - Performance-based model selection - Robust model saving with multiple fallback strategies - Checkpoint management with W&B integration - Centralized storage using @checkpoints/ structure """ import os import json import shutil import logging import torch import glob import pickle import hashlib import random import numpy as np from pathlib import Path from datetime import datetime from dataclasses import dataclass, asdict from typing import Dict, Any, Optional, List, Tuple, Union from collections import defaultdict # W&B import (optional) try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False wandb = None logger = logging.getLogger(__name__) @dataclass class ModelMetrics: """Enhanced performance metrics for model evaluation""" accuracy: float = 0.0 profit_factor: float = 0.0 win_rate: float = 0.0 sharpe_ratio: float = 0.0 max_drawdown: float = 0.0 total_trades: int = 0 avg_trade_duration: float = 0.0 confidence_score: float = 0.0 # Additional metrics from checkpoint_manager loss: Optional[float] = None val_accuracy: Optional[float] = None val_loss: Optional[float] = None reward: Optional[float] = None pnl: Optional[float] = None epoch: Optional[int] = None training_time_hours: Optional[float] = None total_parameters: Optional[int] = None def get_composite_score(self) -> float: """Calculate composite performance score""" # Weighted composite score weights = { 'profit_factor': 0.25, 'sharpe_ratio': 0.2, 'win_rate': 0.15, 'accuracy': 0.15, 'confidence_score': 0.1, 'loss_penalty': 0.1, # New: penalize high loss 'val_penalty': 0.05 # New: penalize validation loss } # Normalize values to 0-1 range normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0 normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1 normalized_win_rate = self.win_rate normalized_accuracy = self.accuracy normalized_confidence = self.confidence_score # Loss penalty (lower loss = higher score) loss_penalty = 1.0 if self.loss is not None and self.loss > 0: loss_penalty = max(0.1, 1 / (1 + self.loss)) # Better loss = higher penalty # Validation penalty val_penalty = 1.0 if self.val_loss is not None and self.val_loss > 0: val_penalty = max(0.1, 1 / (1 + self.val_loss)) # Apply penalties for poor performance drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown score = ( weights['profit_factor'] * normalized_pf + weights['sharpe_ratio'] * normalized_sharpe + weights['win_rate'] * normalized_win_rate + weights['accuracy'] * normalized_accuracy + weights['confidence_score'] * normalized_confidence + weights['loss_penalty'] * loss_penalty + weights['val_penalty'] * val_penalty ) * drawdown_penalty return min(max(score, 0), 1) @dataclass class ModelInfo: """Model information tracking""" model_type: str # 'cnn', 'rl', 'transformer' model_name: str file_path: str creation_time: datetime last_updated: datetime file_size_mb: float metrics: ModelMetrics training_episodes: int = 0 model_version: str = "1.0" def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for JSON serialization""" data = asdict(self) data['creation_time'] = self.creation_time.isoformat() data['last_updated'] = self.last_updated.isoformat() return data 
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create from dictionary"""
        data['creation_time'] = datetime.fromisoformat(data['creation_time'])
        data['last_updated'] = datetime.fromisoformat(data['last_updated'])
        data['metrics'] = ModelMetrics(**data['metrics'])
        return cls(**data)


@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)


class ModelManager:
    """Unified model management system with @checkpoints/ structure"""

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        self.base_dir = Path(base_dir)
        self.config = config or self._get_default_config()

        # Updated directory structure using @checkpoints/
        self.checkpoints_dir = self.base_dir / "@checkpoints"
        self.models_dir = self.checkpoints_dir / "models"
        self.saved_dir = self.checkpoints_dir / "saved"
        self.best_models_dir = self.checkpoints_dir / "best_models"
        self.archive_dir = self.checkpoints_dir / "archive"

        # Model type directories within @checkpoints/
        self.model_dirs = {
            'cnn': self.checkpoints_dir / "cnn",
            'dqn': self.checkpoints_dir / "dqn",
            'rl': self.checkpoints_dir / "rl",
            'transformer': self.checkpoints_dir / "transformer",
            'hybrid': self.checkpoints_dir / "hybrid"
        }

        # Legacy directories for backward compatibility
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.legacy_models_dir = self.base_dir / "models"

        # Metadata and checkpoint management
        self.metadata_file = self.checkpoints_dir / "model_metadata.json"
        self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"

        # Initialize storage
        self._initialize_directories()
        self.metadata = self._load_metadata()
        self.checkpoint_metadata = self._load_checkpoint_metadata()

        logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            'max_checkpoints_per_model': 5,
            'cleanup_old_models': True,
            'auto_archive': True,
            'wandb_enabled': WANDB_AVAILABLE,
            'checkpoint_retention_days': 30
        }

    def _initialize_directories(self):
        """Initialize directory structure"""
        directories = [
            self.checkpoints_dir,
            self.models_dir,
            self.saved_dir,
            self.best_models_dir,
            self.archive_dir
        ] + list(self.model_dirs.values())

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load model metadata"""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                logger.error(f"Error loading metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}
    def _load_checkpoint_metadata(self) -> Dict[str, List[CheckpointMetadata]]:
        """Load checkpoint metadata"""
        if self.checkpoint_metadata_file.exists():
            try:
                with open(self.checkpoint_metadata_file, 'r') as f:
                    data = json.load(f)
                # Convert dict values back to CheckpointMetadata objects;
                # use a defaultdict so new model names can be appended to directly
                result = defaultdict(list)
                for key, checkpoints in data.items():
                    result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
                return result
            except Exception as e:
                logger.error(f"Error loading checkpoint metadata: {e}")
        return defaultdict(list)

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a model checkpoint with enhanced error handling and validation"""
        try:
            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Create checkpoint directory
            checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Generate checkpoint filename
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"
            filename = f"{checkpoint_id}.pt"
            filepath = checkpoint_dir / filename

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'checkpoint_id': checkpoint_id,
                'model_name': model_name,
                'model_type': model_type,
                'performance_score': performance_score,
                'performance_metrics': performance_metrics,
                'training_metadata': training_metadata or {},
                'created_at': datetime.now().isoformat(),
                'version': '2.0'
            }

            torch.save(save_dict, filepath)

            # Create checkpoint metadata
            file_size_mb = filepath.stat().st_size / (1024 * 1024)
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(filepath),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=performance_metrics.get('epoch'),
                training_time_hours=performance_metrics.get('training_time_hours'),
                total_parameters=performance_metrics.get('total_parameters')
            )

            # Store metadata
            self.checkpoint_metadata[model_name].append(metadata)
            self._save_checkpoint_metadata()

            # Rotate checkpoints if needed
            self._rotate_checkpoints(model_name)

            # Upload to W&B if enabled
            if self.config.get('wandb_enabled'):
                self._upload_to_wandb(metadata)

            logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate performance score from metrics"""
        # Simple weighted score - can be enhanced
        weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
        score = 0.0

        for metric, weight in weights.items():
            if metric in metrics:
                score += metrics[metric] * weight

        return score

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Determine if checkpoint should be saved"""
        existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
        if not existing_checkpoints:
            return True
        # Keep if better than the worst existing checkpoint, or if we have fewer than the max
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if len(existing_checkpoints) < max_checkpoints:
            return True

        worst_score = min(cp.performance_score for cp in existing_checkpoints)
        return performance_score > worst_score

    def _rotate_checkpoints(self, model_name: str):
        """Rotate checkpoints to maintain max count"""
        checkpoints = self.checkpoint_metadata.get(model_name, [])
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)

        if len(checkpoints) <= max_checkpoints:
            return

        # Sort by performance score (descending)
        checkpoints.sort(key=lambda x: x.performance_score, reverse=True)

        # Remove excess checkpoints
        to_remove = checkpoints[max_checkpoints:]
        for checkpoint in to_remove:
            try:
                Path(checkpoint.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")

        # Update metadata
        self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints]
        self._save_checkpoint_metadata()

    def _save_checkpoint_metadata(self):
        """Save checkpoint metadata to file"""
        try:
            data = {}
            for model_name, checkpoints in self.checkpoint_metadata.items():
                data[model_name] = [cp.to_dict() for cp in checkpoints]

            with open(self.checkpoint_metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]:
        """Upload checkpoint to W&B"""
        if not WANDB_AVAILABLE:
            return None

        try:
            # This would be implemented based on your W&B workflow
            logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
            return None
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, Optional[CheckpointMetadata]]]:
        """Load the best checkpoint for a model"""
        try:
            # First, try the unified registry
            model_info = self.metadata['models'].get(model_name)
            latest_path = model_info.get('latest_path') if model_info else None
            if latest_path and Path(latest_path).exists():
                # Validate that the registry file is loadable before returning its path
                torch.load(latest_path, map_location='cpu')
                return latest_path, None

            # Fallback to checkpoint metadata
            checkpoints = self.checkpoint_metadata.get(model_name, [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for {model_name}")
                return None

            # Get best checkpoint
            best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
            if not Path(best_checkpoint.file_path).exists():
                logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
                return None

            return best_checkpoint.file_path, best_checkpoint

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics"""
        try:
            total_size = 0
            file_count = 0

            for directory in [self.checkpoints_dir, self.models_dir, self.saved_dir]:
                if directory.exists():
                    for file_path in directory.rglob('*'):
                        if file_path.is_file():
                            total_size += file_path.stat().st_size
                            file_count += 1

            return {
                'total_size_mb': total_size / (1024 * 1024),
                'file_count': file_count,
                'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0
            }
        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}
    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints (compatible with old checkpoint_manager interface)"""
        try:
            stats = {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Count files in different directories as "checkpoints"
            checkpoint_dirs = [
                self.checkpoints_dir / "cnn",
                self.checkpoints_dir / "dqn",
                self.checkpoints_dir / "rl",
                self.checkpoints_dir / "transformer",
                self.checkpoints_dir / "hybrid"
            ]

            total_size = 0
            total_files = 0

            for checkpoint_dir in checkpoint_dirs:
                if checkpoint_dir.exists():
                    model_files = list(checkpoint_dir.rglob('*.pt'))
                    if model_files:
                        model_name = checkpoint_dir.name
                        stats['total_models'] += 1
                        model_size = sum(f.stat().st_size for f in model_files)
                        stats['total_checkpoints'] += len(model_files)
                        stats['total_size_mb'] += model_size / (1024 * 1024)
                        total_size += model_size
                        total_files += len(model_files)

                        # Get the most recent file as "latest"
                        latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
                        stats['models'][model_name] = {
                            'checkpoint_count': len(model_files),
                            'total_size_mb': model_size / (1024 * 1024),
                            'best_performance': 0.0,  # Not tracked in unified system
                            'best_checkpoint_id': latest_file.name,
                            'latest_checkpoint': latest_file.name
                        }

            # Also check saved models directory
            if self.saved_dir.exists():
                saved_files = list(self.saved_dir.rglob('*.pt'))
                if saved_files:
                    stats['total_checkpoints'] += len(saved_files)
                    saved_size = sum(f.stat().st_size for f in saved_files)
                    stats['total_size_mb'] += saved_size / (1024 * 1024)

            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
                'error': str(e)
            }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get model performance leaderboard"""
        try:
            leaderboard = []

            for model_name, model_info in self.metadata['models'].items():
                if 'metrics' in model_info:
                    metrics = ModelMetrics(**model_info['metrics'])
                    leaderboard.append({
                        'model_name': model_name,
                        'model_type': model_info.get('model_type', 'unknown'),
                        'composite_score': metrics.get_composite_score(),
                        'accuracy': metrics.accuracy,
                        'profit_factor': metrics.profit_factor,
                        'win_rate': metrics.win_rate,
                        'last_updated': model_info.get('last_saved', 'unknown')
                    })

            # Sort by composite score
            leaderboard.sort(key=lambda x: x['composite_score'], reverse=True)
            return leaderboard

        except Exception as e:
            logger.error(f"Error getting leaderboard: {e}")
            return []


# ===== LEGACY COMPATIBILITY FUNCTIONS =====

def create_model_manager() -> ModelManager:
    """Create and return a ModelManager instance"""
    return ModelManager()


def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Legacy compatibility function to save a model"""
    manager = create_model_manager()
    return manager.save_model(model, model_name, model_type, metadata)


def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Legacy compatibility function to load a model"""
    manager = create_model_manager()
    return manager.load_model(model_name, model_type, model_class)


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Legacy compatibility function to save a checkpoint"""
    manager = create_model_manager()
    return manager.save_checkpoint(model, model_name, model_type,
                                   performance_metrics, training_metadata, force_save)
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, Optional[CheckpointMetadata]]]:
    """Legacy compatibility function to load the best checkpoint"""
    manager = create_model_manager()
    return manager.load_best_checkpoint(model_name)


# ===== EXAMPLE USAGE =====

if __name__ == "__main__":
    # Example usage of the unified model manager
    manager = create_model_manager()
    print(f"ModelManager initialized at: {manager.checkpoints_dir}")

    # Get storage stats
    stats = manager.get_storage_stats()
    print(f"Storage stats: {stats}")

    # Get leaderboard
    leaderboard = manager.get_model_leaderboard()
    print(f"Models in leaderboard: {len(leaderboard)}")
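
    # Minimal sketch of how ModelMetrics.get_composite_score() weighs metrics.
    # The metric values below are made-up illustration numbers, not real model results.
    demo_metrics = ModelMetrics(
        accuracy=0.58,
        profit_factor=1.9,
        win_rate=0.55,
        sharpe_ratio=1.2,
        max_drawdown=0.08,
        confidence_score=0.6,
        loss=0.45,
    )
    print(f"Composite score for demo metrics: {demo_metrics.get_composite_score():.4f}")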
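
    # Minimal sketch of a checkpoint round trip through save_checkpoint() and
    # load_best_checkpoint(). The tiny torch.nn.Linear model, the "example_cnn"
    # name, and the metric values are placeholder assumptions for illustration;
    # real callers pass their own networks and measured metrics.
    example_model = torch.nn.Linear(8, 2)  # hypothetical stand-in for a real model
    saved = manager.save_checkpoint(
        model=example_model,
        model_name="example_cnn",
        model_type="cnn",
        performance_metrics={"accuracy": 0.62, "loss": 0.41, "profit_factor": 1.8},
        force_save=True,
    )
    if saved is not None:
        print(f"Saved checkpoint: {saved.checkpoint_id} (score: {saved.performance_score:.4f})")
        best = manager.load_best_checkpoint("example_cnn")
        if best is not None:
            best_path, best_meta = best
            print(f"Best checkpoint for example_cnn: {best_path}")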