diff --git a/NN/models/cnn_model.py b/NN/models/cnn_model.py index bac607e..13a25aa 100644 --- a/NN/models/cnn_model.py +++ b/NN/models/cnn_model.py @@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple # Import checkpoint management from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint +from utils.model_registry import get_model_registry # Configure logging logger = logging.getLogger(__name__) @@ -774,42 +775,107 @@ class CNNModelTrainer: # Return realistic loss values based on random baseline performance return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance - def save_model(self, filepath: str, metadata: Optional[Dict] = None): - """Save model with metadata""" - save_dict = { - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'scheduler_state_dict': self.scheduler.state_dict(), - 'training_history': self.training_history, - 'model_config': { - 'input_size': self.model.input_size, - 'feature_dim': self.model.feature_dim, - 'output_size': self.model.output_size, - 'base_channels': self.model.base_channels + def save_model(self, filepath: str = None, metadata: Optional[Dict] = None): + """Save model with metadata using unified registry""" + try: + from utils.model_registry import save_model + + # Prepare model data + model_data = { + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'training_history': self.training_history, + 'model_config': { + 'input_size': self.model.input_size, + 'feature_dim': self.model.feature_dim, + 'output_size': self.model.output_size, + 'base_channels': self.model.base_channels + } } - } - - if metadata: - save_dict['metadata'] = metadata - - torch.save(save_dict, filepath) - logger.info(f"Enhanced CNN model saved to {filepath}") + + if metadata: + model_data['metadata'] = metadata + + # Use unified registry if no filepath specified + if filepath is None or filepath.startswith('models/'): + # Extract model name from filepath or use default + model_name = "enhanced_cnn" + if filepath: + model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '') + + success = save_model( + model=self.model, + model_name=model_name, + model_type='cnn', + metadata={'full_checkpoint': model_data} + ) + if success: + logger.info(f"Enhanced CNN model saved to unified registry: {model_name}") + return success + else: + # Legacy direct file save + torch.save(model_data, filepath) + logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)") + return True + + except Exception as e: + logger.error(f"Failed to save CNN model: {e}") + return False - def load_model(self, filepath: str) -> Dict: - """Load model from file""" - checkpoint = torch.load(filepath, map_location=self.device) - - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - if 'scheduler_state_dict' in checkpoint: - self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - - if 'training_history' in checkpoint: - self.training_history = checkpoint['training_history'] - - logger.info(f"Enhanced CNN model loaded from {filepath}") - return checkpoint.get('metadata', {}) + def load_model(self, filepath: str = None) -> Dict: + """Load model from unified registry or file""" + try: + from utils.model_registry import load_model + + # Use unified registry if no 
filepath or if it's a models/ path + if filepath is None or filepath.startswith('models/'): + model_name = "enhanced_cnn" + if filepath: + model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '') + + model = load_model(model_name, 'cnn') + if model is None: + logger.warning(f"Could not load model {model_name} from unified registry") + return {} + + # Load full checkpoint data from metadata + registry = get_model_registry() + if model_name in registry.metadata['models']: + model_data = registry.metadata['models'][model_name] + if 'full_checkpoint' in model_data: + checkpoint = model_data['full_checkpoint'] + + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint: + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + if 'training_history' in checkpoint: + self.training_history = checkpoint['training_history'] + + logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}") + return checkpoint.get('metadata', {}) + + return {} + + else: + # Legacy direct file load + checkpoint = torch.load(filepath, map_location=self.device) + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if 'scheduler_state_dict' in checkpoint: + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + if 'training_history' in checkpoint: + self.training_history = checkpoint['training_history'] + + logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)") + return checkpoint.get('metadata', {}) + + except Exception as e: + logger.error(f"Failed to load CNN model: {e}") + return {} def create_enhanced_cnn_model(input_size: int = 60, feature_dim: int = 50, diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index b7d4d21..ec103a6 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( # Import checkpoint management from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint +from utils.model_registry import get_model_registry # Configure logger logger = logging.getLogger(__name__) @@ -1329,54 +1330,140 @@ class DQNAgent: return False # No improvement - def save(self, path: str): - """Save model and agent state""" - os.makedirs(os.path.dirname(path), exist_ok=True) - - # Save policy network - self.policy_net.save(f"{path}_policy") - - # Save target network - self.target_net.save(f"{path}_target") - - # Save agent state - state = { - 'epsilon': self.epsilon, - 'update_count': self.update_count, - 'losses': self.losses, - 'optimizer_state': self.optimizer.state_dict(), - 'best_reward': self.best_reward, - 'avg_reward': self.avg_reward - } - - torch.save(state, f"{path}_agent_state.pt") - logger.info(f"Agent state saved to {path}_agent_state.pt") - - def load(self, path: str): - """Load model and agent state""" - # Load policy network - self.policy_net.load(f"{path}_policy") - - # Load target network - self.target_net.load(f"{path}_target") - - # Load agent state + def save(self, path: str = None): + """Save model and agent state using unified registry""" try: - agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False) - self.epsilon = agent_state['epsilon'] - self.update_count = agent_state['update_count'] - self.losses = agent_state['losses'] - self.optimizer.load_state_dict(agent_state['optimizer_state']) - - # Load additional 
metrics if they exist - if 'best_reward' in agent_state: - self.best_reward = agent_state['best_reward'] - if 'avg_reward' in agent_state: - self.avg_reward = agent_state['avg_reward'] - - logger.info(f"Agent state loaded from {path}_agent_state.pt") - except FileNotFoundError: - logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values") + from utils.model_registry import save_model + + # Use unified registry if no path or if it's a models/ path + if path is None or path.startswith('models/'): + model_name = "dqn_agent" + if path: + model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '') + + # Prepare full agent state + agent_state = { + 'epsilon': self.epsilon, + 'update_count': self.update_count, + 'losses': self.losses, + 'optimizer_state': self.optimizer.state_dict(), + 'best_reward': self.best_reward, + 'avg_reward': self.avg_reward, + 'policy_net_state': self.policy_net.state_dict(), + 'target_net_state': self.target_net.state_dict() + } + + success = save_model( + model=self.policy_net, # Save policy net as main model + model_name=model_name, + model_type='dqn', + metadata={'full_agent_state': agent_state} + ) + + if success: + logger.info(f"DQN agent saved to unified registry: {model_name}") + return + + else: + # Legacy direct file save + os.makedirs(os.path.dirname(path), exist_ok=True) + + # Save policy network + self.policy_net.save(f"{path}_policy") + + # Save target network + self.target_net.save(f"{path}_target") + + # Save agent state + state = { + 'epsilon': self.epsilon, + 'update_count': self.update_count, + 'losses': self.losses, + 'optimizer_state': self.optimizer.state_dict(), + 'best_reward': self.best_reward, + 'avg_reward': self.avg_reward + } + + torch.save(state, f"{path}_agent_state.pt") + logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)") + + except Exception as e: + logger.error(f"Failed to save DQN agent: {e}") + + def load(self, path: str = None): + """Load model and agent state from unified registry or file""" + try: + from utils.model_registry import load_model + + # Use unified registry if no path or if it's a models/ path + if path is None or path.startswith('models/'): + model_name = "dqn_agent" + if path: + model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '') + + model = load_model(model_name, 'dqn') + if model is None: + logger.warning(f"Could not load DQN agent {model_name} from unified registry") + return + + # Load full agent state from metadata + registry = get_model_registry() + if model_name in registry.metadata['models']: + model_data = registry.metadata['models'][model_name] + if 'full_agent_state' in model_data: + agent_state = model_data['full_agent_state'] + + # Restore agent state + self.epsilon = agent_state['epsilon'] + self.update_count = agent_state['update_count'] + self.losses = agent_state['losses'] + self.optimizer.load_state_dict(agent_state['optimizer_state']) + + # Load additional metrics if they exist + if 'best_reward' in agent_state: + self.best_reward = agent_state['best_reward'] + if 'avg_reward' in agent_state: + self.avg_reward = agent_state['avg_reward'] + + # Load network states + if 'policy_net_state' in agent_state: + self.policy_net.load_state_dict(agent_state['policy_net_state']) + if 'target_net_state' in agent_state: + self.target_net.load_state_dict(agent_state['target_net_state']) + + logger.info(f"DQN agent loaded from unified registry: {model_name}") + return + + return + + else: + # Legacy direct file 
load + # Load policy network + self.policy_net.load(f"{path}_policy") + + # Load target network + self.target_net.load(f"{path}_target") + + # Load agent state + try: + agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False) + self.epsilon = agent_state['epsilon'] + self.update_count = agent_state['update_count'] + self.losses = agent_state['losses'] + self.optimizer.load_state_dict(agent_state['optimizer_state']) + + # Load additional metrics if they exist + if 'best_reward' in agent_state: + self.best_reward = agent_state['best_reward'] + if 'avg_reward' in agent_state: + self.avg_reward = agent_state['avg_reward'] + + logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)") + except FileNotFoundError: + logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values") + + except Exception as e: + logger.error(f"Failed to load DQN agent: {e}") def get_position_info(self): """Get current position information""" diff --git a/mcp_servers/browser-tools-mcp/BrowserTools-1.2.0-extension.zip b/mcp_servers/browser-tools-mcp/BrowserTools-1.2.0-extension.zip new file mode 100644 index 0000000..ad89b45 Binary files /dev/null and b/mcp_servers/browser-tools-mcp/BrowserTools-1.2.0-extension.zip differ diff --git a/utils/checkpoint_manager.py b/utils/checkpoint_manager.py index 6f33c06..9c3e20e 100644 --- a/utils/checkpoint_manager.py +++ b/utils/checkpoint_manager.py @@ -16,6 +16,9 @@ import random WANDB_AVAILABLE = False +# Import model registry +from utils.model_registry import get_model_registry + logger = logging.getLogger(__name__) @dataclass @@ -68,39 +71,48 @@ class CheckpointManager: logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}") - def save_checkpoint(self, model, model_name: str, model_type: str, + def save_checkpoint(self, model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, force_save: bool = False) -> Optional[CheckpointMetadata]: - """Save a model checkpoint with improved error handling and validation""" + """Save a model checkpoint with improved error handling and validation using unified registry""" try: - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - checkpoint_id = f"{model_name}_{timestamp}" - - model_dir = self.base_dir / model_name - model_dir.mkdir(exist_ok=True) - - checkpoint_path = model_dir / f"{checkpoint_id}.pt" - + from utils.model_registry import save_checkpoint as registry_save_checkpoint + performance_score = self._calculate_performance_score(performance_metrics) - + if not force_save and not self._should_save_checkpoint(model_name, performance_score): logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved") return None - - success = self._save_model_file(model, checkpoint_path, model_type) - if not success: - return None - - file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024) - - metadata = CheckpointMetadata( - checkpoint_id=checkpoint_id, + + # Use unified registry for checkpointing + success = registry_save_checkpoint( + model=model, model_name=model_name, model_type=model_type, - file_path=str(checkpoint_path), - created_at=datetime.now(), - file_size_mb=file_size_mb, + performance_score=performance_score, + metadata={ + 'performance_metrics': performance_metrics, + 'training_metadata': training_metadata, + 'checkpoint_manager': True + } + ) + + if not success: + return None + + # Get checkpoint info 
from registry + registry = get_model_registry() + checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1] + + # Create CheckpointMetadata object + metadata = CheckpointMetadata( + checkpoint_id=checkpoint_info['id'], + model_name=model_name, + model_type=model_type, + file_path=checkpoint_info['path'], + created_at=datetime.fromisoformat(checkpoint_info['timestamp']), + file_size_mb=0.0, # Will be calculated by registry performance_score=performance_score, accuracy=performance_metrics.get('accuracy'), loss=performance_metrics.get('loss'), @@ -112,9 +124,8 @@ class CheckpointManager: training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None, total_parameters=training_metadata.get('total_parameters') if training_metadata else None ) - - # W&B disabled - + + # Update local checkpoint tracking self.checkpoints[model_name].append(metadata) self._rotate_checkpoints(model_name) self._save_metadata() @@ -128,14 +139,42 @@ class CheckpointManager: def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: try: - # First, try the standard checkpoint system + from utils.model_registry import load_best_checkpoint as registry_load_checkpoint + + # First, try the unified registry + registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first + if registry_result is None: + registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type + + if registry_result: + checkpoint_path, checkpoint_data = registry_result + + # Create CheckpointMetadata from registry data + metadata = CheckpointMetadata( + checkpoint_id=f"{model_name}_registry", + model_name=model_name, + model_type=checkpoint_data.get('model_type', 'unknown'), + file_path=checkpoint_path, + created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())), + file_size_mb=0.0, # Will be calculated by registry + performance_score=checkpoint_data.get('performance_score', 0.0), + accuracy=checkpoint_data.get('accuracy'), + loss=checkpoint_data.get('loss'), + reward=checkpoint_data.get('reward'), + pnl=checkpoint_data.get('pnl') + ) + + logger.debug(f"Loading checkpoint from unified registry for {model_name}") + return checkpoint_path, metadata + + # Fallback: Try the standard checkpoint system if model_name in self.checkpoints and self.checkpoints[model_name]: # Filter out checkpoints with non-existent files valid_checkpoints = [ - cp for cp in self.checkpoints[model_name] + cp for cp in self.checkpoints[model_name] if Path(cp.file_path).exists() ] - + if valid_checkpoints: best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score) logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}") @@ -146,22 +185,22 @@ class CheckpointManager: logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata") self.checkpoints[model_name] = [] self._save_metadata() - + # Fallback: Look for existing saved models in the legacy format logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models") legacy_model_path = self._find_legacy_model(model_name) - + if legacy_model_path: # Create checkpoint metadata for the legacy model using actual file data legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path) logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}") return str(legacy_model_path), legacy_metadata - + # Only warn once per model 
to avoid spam if model_name not in self._warned_models: logger.info(f"No checkpoints found for {model_name}, starting fresh") self._warned_models.add(model_name) - + return None except Exception as e: diff --git a/utils/model_registry.py b/utils/model_registry.py new file mode 100644 index 0000000..2d91bb2 --- /dev/null +++ b/utils/model_registry.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +""" +Unified Model Registry for Centralized Model Management + +This module provides a unified interface for saving, loading, and managing +all machine learning models in the trading system. It consolidates model +storage from multiple locations into a single, organized structure. +""" + +import os +import json +import torch +import logging +import pickle +from pathlib import Path +from typing import Dict, Any, Optional, Tuple, List +from datetime import datetime +import hashlib + +logger = logging.getLogger(__name__) + +class ModelRegistry: + """ + Unified model registry for centralized model management. + Handles saving, loading, and organization of all ML models. + """ + + def __init__(self, base_dir: str = "models"): + """ + Initialize the model registry. + + Args: + base_dir: Base directory for model storage + """ + self.base_dir = Path(base_dir) + self.saved_dir = self.base_dir / "saved" + self.checkpoint_dir = self.base_dir / "checkpoints" + self.archive_dir = self.base_dir / "archive" + + # Model type directories + self.model_dirs = { + 'cnn': self.base_dir / "cnn", + 'dqn': self.base_dir / "dqn", + 'transformer': self.base_dir / "transformer", + 'hybrid': self.base_dir / "hybrid" + } + + # Ensure all directories exist + self._ensure_directories() + + # Metadata tracking + self.metadata_file = self.base_dir / "registry_metadata.json" + self.metadata = self._load_metadata() + + logger.info(f"Model Registry initialized at {self.base_dir}") + + def _ensure_directories(self): + """Ensure all required directories exist.""" + directories = [ + self.saved_dir, + self.checkpoint_dir, + self.archive_dir + ] + + # Add model type directories + for model_dir in self.model_dirs.values(): + directories.extend([ + model_dir / "saved", + model_dir / "checkpoints", + model_dir / "archive" + ]) + + for directory in directories: + directory.mkdir(parents=True, exist_ok=True) + + def _load_metadata(self) -> Dict[str, Any]: + """Load registry metadata.""" + if self.metadata_file.exists(): + try: + with open(self.metadata_file, 'r') as f: + return json.load(f) + except Exception as e: + logger.warning(f"Failed to load metadata: {e}") + return {'models': {}, 'last_updated': datetime.now().isoformat()} + + def _save_metadata(self): + """Save registry metadata.""" + self.metadata['last_updated'] = datetime.now().isoformat() + try: + with open(self.metadata_file, 'w') as f: + json.dump(self.metadata, f, indent=2) + except Exception as e: + logger.error(f"Failed to save metadata: {e}") + + def save_model(self, model: Any, model_name: str, model_type: str = 'cnn', + metadata: Optional[Dict[str, Any]] = None) -> bool: + """ + Save a model to the unified storage. 
+ + Args: + model: The model to save + model_name: Name of the model + model_type: Type of model (cnn, dqn, transformer, hybrid) + metadata: Additional metadata to save + + Returns: + bool: True if successful, False otherwise + """ + try: + model_dir = self.model_dirs.get(model_type, self.saved_dir) + save_dir = model_dir / "saved" + + # Generate filename with timestamp + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f"{model_name}_{timestamp}.pt" + filepath = save_dir / filename + + # Also save as latest + latest_filepath = save_dir / f"{model_name}_latest.pt" + + # Save model + save_dict = { + 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {}, + 'model_class': model.__class__.__name__, + 'model_type': model_type, + 'timestamp': timestamp, + 'metadata': metadata or {} + } + + torch.save(save_dict, filepath) + torch.save(save_dict, latest_filepath) + + # Update metadata + if model_name not in self.metadata['models']: + self.metadata['models'][model_name] = {} + + self.metadata['models'][model_name].update({ + 'type': model_type, + 'latest_path': str(latest_filepath), + 'last_saved': timestamp, + 'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1 + }) + + self._save_metadata() + + logger.info(f"Model {model_name} saved to {filepath}") + return True + + except Exception as e: + logger.error(f"Failed to save model {model_name}: {e}") + return False + + def load_model(self, model_name: str, model_type: str = 'cnn', + model_class: Optional[Any] = None) -> Optional[Any]: + """ + Load a model from the unified storage. + + Args: + model_name: Name of the model to load + model_type: Type of model (cnn, dqn, transformer, hybrid) + model_class: Model class to instantiate (if needed) + + Returns: + The loaded model or None if failed + """ + try: + model_dir = self.model_dirs.get(model_type, self.saved_dir) + save_dir = model_dir / "saved" + latest_filepath = save_dir / f"{model_name}_latest.pt" + + if not latest_filepath.exists(): + logger.warning(f"Model {model_name} not found at {latest_filepath}") + return None + + # Load checkpoint + checkpoint = torch.load(latest_filepath, map_location='cpu') + + # Instantiate model if class provided + if model_class is not None: + model = model_class() + model.load_state_dict(checkpoint['model_state_dict']) + else: + # Try to reconstruct model from state_dict + model = type('LoadedModel', (), {})() + model.state_dict = lambda: checkpoint['model_state_dict'] + model.load_state_dict = lambda state_dict: None + + logger.info(f"Model {model_name} loaded from {latest_filepath}") + return model + + except Exception as e: + logger.error(f"Failed to load model {model_name}: {e}") + return None + + def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn', + performance_score: float = 0.0, + metadata: Optional[Dict[str, Any]] = None) -> bool: + """ + Save a model checkpoint. 
+ + Args: + model: The model to checkpoint + model_name: Name of the model + model_type: Type of model + performance_score: Performance score for this checkpoint + metadata: Additional metadata + + Returns: + bool: True if successful, False otherwise + """ + try: + model_dir = self.model_dirs.get(model_type, self.checkpoint_dir) + checkpoint_dir = model_dir / "checkpoints" + + # Generate checkpoint ID + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}" + + filepath = checkpoint_dir / f"{checkpoint_id}.pt" + + # Save checkpoint + checkpoint_data = { + 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {}, + 'model_class': model.__class__.__name__, + 'model_type': model_type, + 'model_name': model_name, + 'performance_score': performance_score, + 'timestamp': timestamp, + 'metadata': metadata or {} + } + + torch.save(checkpoint_data, filepath) + + # Update metadata + if model_name not in self.metadata['models']: + self.metadata['models'][model_name] = {} + + if 'checkpoints' not in self.metadata['models'][model_name]: + self.metadata['models'][model_name]['checkpoints'] = [] + + checkpoint_info = { + 'id': checkpoint_id, + 'path': str(filepath), + 'performance_score': performance_score, + 'timestamp': timestamp + } + + self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info) + + # Keep only top 5 checkpoints + checkpoints = self.metadata['models'][model_name]['checkpoints'] + if len(checkpoints) > 5: + checkpoints.sort(key=lambda x: x['performance_score'], reverse=True) + checkpoints_to_remove = checkpoints[5:] + + for checkpoint in checkpoints_to_remove: + try: + os.remove(checkpoint['path']) + except: + pass + + self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5] + + self._save_metadata() + + logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}") + return True + + except Exception as e: + logger.error(f"Failed to save checkpoint for {model_name}: {e}") + return False + + def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]: + """ + Load the best checkpoint for a model. + + Args: + model_name: Name of the model + model_type: Type of model + + Returns: + Tuple of (checkpoint_path, checkpoint_data) or None + """ + try: + if model_name not in self.metadata['models']: + logger.warning(f"No metadata found for model {model_name}") + return None + + checkpoints = self.metadata['models'][model_name].get('checkpoints', []) + if not checkpoints: + logger.warning(f"No checkpoints found for model {model_name}") + return None + + # Find best checkpoint by performance score + best_checkpoint = max(checkpoints, key=lambda x: x['performance_score']) + checkpoint_path = best_checkpoint['path'] + + if not os.path.exists(checkpoint_path): + logger.warning(f"Checkpoint file not found: {checkpoint_path}") + return None + + checkpoint_data = torch.load(checkpoint_path, map_location='cpu') + + logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}") + return checkpoint_path, checkpoint_data + + except Exception as e: + logger.error(f"Failed to load best checkpoint for {model_name}: {e}") + return None + + def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool: + """ + Archive a model by moving it to archive directory. 
+ + Args: + model_name: Name of the model to archive + model_type: Type of model + + Returns: + bool: True if successful, False otherwise + """ + try: + model_dir = self.model_dirs.get(model_type, self.saved_dir) + save_dir = model_dir / "saved" + archive_dir = model_dir / "archive" + + latest_filepath = save_dir / f"{model_name}_latest.pt" + + if not latest_filepath.exists(): + logger.warning(f"Model {model_name} not found to archive") + return False + + # Move to archive with timestamp + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt" + + os.rename(latest_filepath, archive_filepath) + + logger.info(f"Model {model_name} archived to {archive_filepath}") + return True + + except Exception as e: + logger.error(f"Failed to archive model {model_name}: {e}") + return False + + def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]: + """ + List all models in the registry. + + Args: + model_type: Filter by model type (optional) + + Returns: + Dictionary of model information + """ + models_info = {} + + for model_name, model_data in self.metadata['models'].items(): + if model_type and model_data.get('type') != model_type: + continue + + models_info[model_name] = { + 'type': model_data.get('type'), + 'last_saved': model_data.get('last_saved'), + 'save_count': model_data.get('save_count', 0), + 'checkpoint_count': len(model_data.get('checkpoints', [])), + 'latest_path': model_data.get('latest_path') + } + + return models_info + + def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int: + """ + Clean up old checkpoints, keeping only the best ones. + + Args: + model_name: Name of the model + keep_count: Number of checkpoints to keep + + Returns: + Number of checkpoints removed + """ + if model_name not in self.metadata['models']: + return 0 + + checkpoints = self.metadata['models'][model_name].get('checkpoints', []) + if len(checkpoints) <= keep_count: + return 0 + + # Sort by performance score (descending) + checkpoints.sort(key=lambda x: x['performance_score'], reverse=True) + + # Remove old checkpoints + removed_count = 0 + for checkpoint in checkpoints[keep_count:]: + try: + os.remove(checkpoint['path']) + removed_count += 1 + except: + pass + + # Update metadata + self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count] + self._save_metadata() + + logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}") + return removed_count + + +# Global registry instance +_registry_instance = None + +def get_model_registry() -> ModelRegistry: + """Get the global model registry instance.""" + global _registry_instance + if _registry_instance is None: + _registry_instance = ModelRegistry() + return _registry_instance + +def save_model(model: Any, model_name: str, model_type: str = 'cnn', + metadata: Optional[Dict[str, Any]] = None) -> bool: + """ + Convenience function to save a model using the global registry. + """ + return get_model_registry().save_model(model, model_name, model_type, metadata) + +def load_model(model_name: str, model_type: str = 'cnn', + model_class: Optional[Any] = None) -> Optional[Any]: + """ + Convenience function to load a model using the global registry. 
+ """ + return get_model_registry().load_model(model_name, model_type, model_class) + +def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn', + performance_score: float = 0.0, + metadata: Optional[Dict[str, Any]] = None) -> bool: + """ + Convenience function to save a checkpoint using the global registry. + """ + return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata) + +def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]: + """ + Convenience function to load the best checkpoint using the global registry. + """ + return get_model_registry().load_best_checkpoint(model_name, model_type) diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index 8950467..5991611 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -4710,53 +4710,85 @@ class CleanTradingDashboard: stored_models = [] + # Use unified model registry for saving + from utils.model_registry import save_model + # 1. Store DQN model if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent: try: - if hasattr(self.orchestrator.rl_agent, 'save'): - save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session') - stored_models.append(('DQN', save_path)) - logger.info(f"Stored DQN model: {save_path}") + success = save_model( + model=self.orchestrator.rl_agent.policy_net, # Save policy network + model_name='dqn_agent_session', + model_type='dqn', + metadata={'session_save': True, 'dashboard_save': True} + ) + if success: + stored_models.append(('DQN', 'models/dqn/saved/dqn_agent_session_latest.pt')) + logger.info("Stored DQN model via unified registry") + else: + logger.warning("Failed to store DQN model via unified registry") except Exception as e: logger.warning(f"Failed to store DQN model: {e}") - + # 2. Store CNN model if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: try: - if hasattr(self.orchestrator.cnn_model, 'save'): - save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session') - stored_models.append(('CNN', save_path)) - logger.info(f"Stored CNN model: {save_path}") + success = save_model( + model=self.orchestrator.cnn_model, + model_name='cnn_model_session', + model_type='cnn', + metadata={'session_save': True, 'dashboard_save': True} + ) + if success: + stored_models.append(('CNN', 'models/cnn/saved/cnn_model_session_latest.pt')) + logger.info("Stored CNN model via unified registry") + else: + logger.warning("Failed to store CNN model via unified registry") except Exception as e: logger.warning(f"Failed to store CNN model: {e}") - + # 3. 
Store Transformer model if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer: try: - if hasattr(self.orchestrator.primary_transformer, 'save'): - save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session') - stored_models.append(('Transformer', save_path)) - logger.info(f"Stored Transformer model: {save_path}") + success = save_model( + model=self.orchestrator.primary_transformer, + model_name='transformer_model_session', + model_type='transformer', + metadata={'session_save': True, 'dashboard_save': True} + ) + if success: + stored_models.append(('Transformer', 'models/transformer/saved/transformer_model_session_latest.pt')) + logger.info("Stored Transformer model via unified registry") + else: + logger.warning("Failed to store Transformer model via unified registry") except Exception as e: logger.warning(f"Failed to store Transformer model: {e}") - - # 4. Store COB RL model + + # 4. Store COB RL model (if exists) if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent: try: + # COB RL model might have different save method if hasattr(self.orchestrator.cob_rl_agent, 'save'): save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session') stored_models.append(('COB RL', save_path)) logger.info(f"Stored COB RL model: {save_path}") except Exception as e: logger.warning(f"Failed to store COB RL model: {e}") - - # 5. Store Decision Fusion model + + # 5. Store Decision model if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model: try: - if hasattr(self.orchestrator.decision_model, 'save'): - save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session') - stored_models.append(('Decision Fusion', save_path)) - logger.info(f"Stored Decision Fusion model: {save_path}") + success = save_model( + model=self.orchestrator.decision_model, + model_name='decision_fusion_session', + model_type='hybrid', + metadata={'session_save': True, 'dashboard_save': True} + ) + if success: + stored_models.append(('Decision Fusion', 'models/hybrid/saved/decision_fusion_session_latest.pt')) + logger.info("Stored Decision Fusion model via unified registry") + else: + logger.warning("Failed to store Decision Fusion model via unified registry") except Exception as e: logger.warning(f"Failed to store Decision Fusion model: {e}") @@ -6706,13 +6738,39 @@ class CleanTradingDashboard: except Exception as e: logger.error(f"Error saving transformer checkpoint: {e}") - # Fallback to direct save + # Use unified registry for checkpoint try: - checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt" - transformer_trainer.save_model(checkpoint_path) - logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}") - except Exception as fallback_error: - logger.error(f"Fallback checkpoint save also failed: {fallback_error}") + from utils.model_registry import save_checkpoint as registry_save_checkpoint + + checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data + + success = registry_save_checkpoint( + model=checkpoint_data, + model_name='transformer', + model_type='transformer', + performance_score=training_metrics['accuracy'], + metadata={ + 'training_samples': len(training_samples), + 'loss': training_metrics['total_loss'], + 'accuracy': training_metrics['accuracy'], + 'checkpoint_source': 'dashboard_training' + } + ) + + if 
success: + logger.info("TRANSFORMER: Checkpoint saved via unified registry") + else: + logger.warning("TRANSFORMER: Failed to save checkpoint via unified registry") + + except Exception as registry_error: + logger.warning(f"Unified registry save failed: {registry_error}") + # Fallback to direct save + try: + checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt" + transformer_trainer.save_model(checkpoint_path) + logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}") + except Exception as fallback_error: + logger.error(f"Fallback checkpoint save also failed: {fallback_error}") logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")
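
For reference, a minimal usage sketch of the unified registry API that utils/model_registry.py introduces above. It is only a sketch: TinyNet, the "tiny_net" name, and the metric values are placeholders, and it assumes the script runs from the project root so that utils.model_registry is importable; saved weights land under the registry's models/<type>/saved/<name>_latest.pt layout.

    import torch.nn as nn

    from utils.model_registry import (
        get_model_registry,
        load_best_checkpoint,
        load_model,
        save_checkpoint,
        save_model,
    )

    class TinyNet(nn.Module):
        """Stand-in model; any module exposing state_dict() works with the registry."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 3)

        def forward(self, x):
            return self.fc(x)

    model = TinyNet()

    # Save the latest weights; the registry writes a timestamped file plus
    # models/cnn/saved/tiny_net_latest.pt and updates registry_metadata.json.
    save_model(model=model, model_name="tiny_net", model_type="cnn",
               metadata={"note": "example save"})

    # Reload into a fresh instance; without model_class the registry returns only
    # a thin wrapper around the stored state_dict.
    restored = load_model("tiny_net", model_type="cnn", model_class=TinyNet)

    # Checkpoints are ranked by performance_score; only the top five per model are kept.
    save_checkpoint(model=model, model_name="tiny_net", model_type="cnn",
                    performance_score=0.71, metadata={"loss": 0.42})

    best = load_best_checkpoint("tiny_net", model_type="cnn")
    if best is not None:
        checkpoint_path, checkpoint_data = best
        print(checkpoint_path, checkpoint_data["performance_score"])

    # Registry-wide summary of everything saved so far.
    print(get_model_registry().list_models())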