#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training

Tracks per-model checkpoint metadata in a JSON file, delegates the actual
model serialization to the unified model registry (``utils.model_registry``),
keeps only the best-scoring checkpoints per model, and can fall back to
"legacy" ``.pt`` files found on disk when no tracked checkpoint exists.
"""

import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
import random

# W&B integration is switched off in this module; _upload_to_wandb is a no-op.
WANDB_AVAILABLE = False

# Import model registry
from utils.model_registry import get_model_registry

logger = logging.getLogger(__name__)


@dataclass
class CheckpointMetadata:
    """Descriptive record for one saved checkpoint.

    The first seven fields are required; the remaining metric and
    bookkeeping fields default to ``None`` when unknown.
    """

    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict; ``created_at`` becomes an ISO string."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Build an instance from a dict produced by :meth:`to_dict`.

        The input mapping is not modified: the ISO ``created_at`` string is
        parsed on a shallow copy so callers can keep reusing their dict.
        """
        fields = dict(data)
        fields['created_at'] = datetime.fromisoformat(fields['created_at'])
        return cls(**fields)


class CheckpointManager:
    """Keeps track of the best checkpoints for each model name."""

    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = False):
        """Create the checkpoint directory if needed and load existing metadata.

        Args:
            base_checkpoint_dir: Directory holding checkpoints and the metadata file.
            max_checkpoints_per_model: Number of checkpoints retained per model.
            metadata_file: File name (inside the base directory) of the JSON index.
            enable_wandb: Accepted for compatibility; W&B stays disabled regardless.
        """
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = False

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._warned_models = set()  # Track models we've warned about to reduce spam
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a model checkpoint through the unified registry.

        Args:
            model: Model object handed to the registry for serialization.
            model_name: Key under which checkpoints are tracked.
            model_type: Model type label forwarded to the registry.
            performance_metrics: Metrics used to compute the performance score.
            training_metadata: Optional extras; ``epoch``, ``training_time_hours``
                and ``total_parameters`` are copied into the metadata if present.
            force_save: Skip the "is it worth saving" check when True.

        Returns:
            The metadata of the stored checkpoint, or ``None`` when the save was
            skipped, the registry reported failure, or an error occurred.
        """
        try:
            from utils.model_registry import save_checkpoint as registry_save_checkpoint

            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Use unified registry for checkpointing
            success = registry_save_checkpoint(
                model=model,
                model_name=model_name,
                model_type=model_type,
                performance_score=performance_score,
                metadata={
                    'performance_metrics': performance_metrics,
                    'training_metadata': training_metadata,
                    'checkpoint_manager': True
                }
            )

            if not success:
                return None

            # Get checkpoint info from registry (the most recently appended entry)
            registry = get_model_registry()
            checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]

            # An absent/empty training_metadata simply yields None for each extra.
            extras = training_metadata or {}

            # Create CheckpointMetadata object
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_info['id'],
                model_name=model_name,
                model_type=model_type,
                file_path=checkpoint_info['path'],
                created_at=datetime.fromisoformat(checkpoint_info['timestamp']),
                file_size_mb=0.0,  # Not measured here; placeholder value
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=extras.get('epoch'),
                training_time_hours=extras.get('training_time_hours'),
                total_parameters=extras.get('total_parameters')
            )

            # Update local checkpoint tracking
            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            # Use the id stored on the metadata object; a bare `checkpoint_id`
            # name does not exist in this scope and used to raise NameError,
            # turning every successful save into a reported failure.
            logger.debug(f"Saved checkpoint: {metadata.checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Locate the best available checkpoint for ``model_name``.

        Lookup order: unified registry (type ``'cnn'``, then ``'dqn'``), locally
        tracked checkpoints whose files still exist, then legacy files on disk.

        Returns:
            ``(path, metadata)`` for the chosen checkpoint, or ``None`` when
            nothing is found or an error occurs.
        """
        try:
            from utils.model_registry import load_best_checkpoint as registry_load_checkpoint

            # First, try the unified registry
            registry_result = registry_load_checkpoint(model_name, 'cnn')  # Try CNN type first
            if registry_result is None:
                registry_result = registry_load_checkpoint(model_name, 'dqn')  # Try DQN type

            if registry_result:
                checkpoint_path, checkpoint_data = registry_result

                # Create CheckpointMetadata from registry data
                metadata = CheckpointMetadata(
                    checkpoint_id=f"{model_name}_registry",
                    model_name=model_name,
                    model_type=checkpoint_data.get('model_type', 'unknown'),
                    file_path=checkpoint_path,
                    created_at=datetime.fromisoformat(
                        checkpoint_data.get('timestamp', datetime.now().isoformat())
                    ),
                    file_size_mb=0.0,  # Not measured here; placeholder value
                    performance_score=checkpoint_data.get('performance_score', 0.0),
                    accuracy=checkpoint_data.get('accuracy'),
                    loss=checkpoint_data.get('loss'),
                    reward=checkpoint_data.get('reward'),
                    pnl=checkpoint_data.get('pnl')
                )

                logger.debug(f"Loading checkpoint from unified registry for {model_name}")
                return checkpoint_path, metadata

            # Fallback: Try the standard checkpoint system
            if model_name in self.checkpoints and self.checkpoints[model_name]:
                # Filter out checkpoints with non-existent files
                valid_checkpoints = [
                    cp for cp in self.checkpoints[model_name]
                    if Path(cp.file_path).exists()
                ]

                if valid_checkpoints:
                    best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
                    logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
                    return best_checkpoint.file_path, best_checkpoint

                # Clean up invalid metadata entries
                invalid_count = len(self.checkpoints[model_name])
                logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
                self.checkpoints[model_name] = []
                self._save_metadata()

            # Fallback: Look for existing saved models in the legacy format
            logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
            legacy_model_path = self._find_legacy_model(model_name)
            if legacy_model_path:
                # Create checkpoint metadata for the legacy model using actual file data
                legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
                logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
                return str(legacy_model_path), legacy_metadata

            # Only warn once per model to avoid spam
            if model_name not in self._warned_models:
                logger.info(f"No checkpoints found for {model_name}, starting fresh")
                self._warned_models.add(model_name)
            return None

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Combine the supplied metrics into a single score (higher is better).

        Only keys present in ``metrics`` contribute; an empty dict scores 0.0.
        """
        score = 0.0

        # Prioritize loss reduction for active training models
        if 'loss' in metrics:
            # Invert loss so lower loss = higher score, with better scaling
            loss_value = metrics['loss']
            if loss_value > 0:
                score += max(0, 100 / (1 + loss_value))  # More sensitive to loss changes
            else:
                score += 100  # Perfect loss

        # Add other metrics with appropriate weights
        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 50  # Reduced weight to balance with loss
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 50
        if 'val_loss' in metrics:
            val_loss = metrics['val_loss']
            # Unlike 'loss', a non-positive val_loss contributes nothing.
            if val_loss > 0:
                score += max(0, 50 / (1 + val_loss))
        if 'reward' in metrics:
            score += metrics['reward'] * 10
        if 'pnl' in metrics:
            score += metrics['pnl'] * 5
        if 'training_samples' in metrics:
            # Bonus for processing more training samples (capped at 10)
            score += min(10, metrics['training_samples'] / 10)

        # Return actual calculated score - NO SYNTHETIC MINIMUM
        return score

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Decide whether a checkpoint with this score is worth saving.

        Note: the final rule is random (20% chance), so a ``False`` outcome is
        not deterministic for scores that fail every other rule.
        """
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True  # Always save first checkpoint

        # Allow more checkpoints during active training
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        # Get current best and worst scores
        scores = [cp.performance_score for cp in self.checkpoints[model_name]]
        best_score = max(scores)
        worst_score = min(scores)

        # Save if better than worst (more frequent saves)
        if performance_score > worst_score:
            return True

        # For high-performing models (score > 100), be more sensitive to small improvements
        if best_score > 100:
            # Save if within 0.1% of best score (very sensitive for converged models)
            if performance_score >= best_score * 0.999:
                return True
        else:
            # Also save if we're within 10% of best score (capture near-optimal models)
            if performance_score >= best_score * 0.9:
                return True

        # Save more frequently during active training
        if random.random() < 0.2:  # 20% chance to save anyway
            logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
            return True

        return False

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        """Write ``model`` to ``file_path`` with ``torch.save``.

        Objects exposing ``state_dict`` are stored as a dict holding the state
        dict, the model type and a timestamp; anything else is saved as-is.

        Returns:
            True on success, False if saving raised.
        """
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        """Keep the ``max_checkpoints`` best-scoring entries and delete the rest.

        Files of dropped entries are removed from disk when they exist; a
        failure to delete one file is logged and does not stop the rotation.
        """
        checkpoint_list = self.checkpoints[model_name]

        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)

        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        """No-op placeholder: W&B upload is disabled, so this always returns None."""
        return None

    def _load_metadata(self):
        """Populate ``self.checkpoints`` from the JSON metadata file, if it exists.

        Any error (unreadable file, malformed entry) is logged, not raised.
        """
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        """Write all tracked checkpoint metadata to the JSON metadata file.

        Errors are logged, not raised.
        """
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoint_list]
                for model_name, checkpoint_list in self.checkpoints.items()
            }

            with open(self.metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints.

        Returns:
            A dict with ``total_models``, ``total_checkpoints``,
            ``total_size_mb`` and a per-model ``models`` mapping. Models with
            no checkpoints are counted in ``total_models`` but get no entry
            in ``models``.
        """
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }

            stats['total_size_mb'] += model_size

        return stats

    def _find_legacy_model(self, model_name: str) -> Optional[Path]:
        """Find legacy saved models based on model name patterns.

        Search order:
          1. Known file names in a fixed list of directories.
          2. The same file names inside ``<base_dir>/<model_name>/``.
          3. Any ``*best*.pt`` file under the inferred project ``models``
             directories whose name contains a hint for this model; the most
             recently modified match wins.

        Returns:
            Path of the first/best match, or ``None`` if nothing matches.
        """
        base_dir = Path(self.base_dir)

        # Additional search locations
        search_dirs = [
            base_dir,
            Path("models/saved"),
            Path("NN/models/saved"),
            Path("models"),
            Path("models/archive"),
            Path("models/backtest")
        ]

        # Define model name mappings and patterns for legacy files
        legacy_patterns = {
            'dqn_agent': [
                'dqn_agent_session_policy.pt',
                'dqn_agent_session_agent_state.pt',
                'dqn_agent_best_policy.pt',
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt',
                'dqn_agent_final_policy.pt',
                'trading_agent_best_pnl.pt'
            ],
            'enhanced_cnn': [
                'cnn_model_session.pt',
                'cnn_model_best.pt',
                'optimized_short_term_model_best.pt',
                'optimized_short_term_model_realtime_best.pt',
                'optimized_short_term_model_ticks_best.pt'
            ],
            'extrema_trainer': [
                'supervised_model_best.pt'
            ],
            'cob_rl': [
                'best_rl_model.pth_policy.pt',
                'rl_agent_best_policy.pt'
            ],
            'decision': [
                # Decision models might be in subdirectories, but let's check main dir too
                'decision_best.pt',
                'decision_model_best.pt',
                # Check for transformer models which might be used as decision models
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt'
            ]
        }

        # Known patterns for this model name, followed by generic name-based ones
        patterns = [
            *legacy_patterns.get(model_name, []),
            f'{model_name}_best.pt',
            f'{model_name}_best_policy.pt',
            f'{model_name}_final.pt',
            f'{model_name}_final_policy.pt'
        ]

        # Search for the model files in all search directories
        for search_dir in search_dirs:
            if not search_dir.exists():
                continue

            for pattern in patterns:
                candidate_path = search_dir / pattern
                if candidate_path.exists():
                    logger.info(f"Found legacy model file: {candidate_path}")
                    return candidate_path

        # Also check subdirectories
        for subdir in base_dir.iterdir():
            if subdir.is_dir() and subdir.name == model_name:
                for pattern in patterns:
                    candidate_path = subdir / pattern
                    if candidate_path.exists():
                        logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                        return candidate_path

        # Extended search: scan common project model directories for best checkpoints
        try:
            # Attempt to infer project root from base_dir (NN/models/saved -> root)
            project_root = base_dir.resolve().parent.parent.parent
        except Exception:
            project_root = Path(".").resolve()

        additional_dirs = [
            project_root / "models",
            project_root / "models" / "archive",
            project_root / "models" / "backtest",
        ]

        def _match_legacy_name(candidate: Path, model: str) -> bool:
            """Return True if the file name contains any hint substring for ``model``."""
            name = candidate.name.lower()
            model_keys = {
                'dqn_agent': ['dqn', 'agent', 'policy'],
                'enhanced_cnn': ['cnn', 'optimized_short_term'],
                'extrema_trainer': ['supervised', 'extrema'],
                'cob_rl': ['cob', 'rl', 'policy'],
                'decision': ['decision', 'transformer']
            }.get(model, [model])
            return any(k in name for k in model_keys)

        candidates: List[Path] = []
        for adir in additional_dirs:
            if not adir.exists():
                continue
            try:
                for pt in adir.rglob('*.pt'):
                    # Prefer files that indicate "best" and match model hints
                    lname = pt.name.lower()
                    if 'best' in lname and _match_legacy_name(pt, model_name):
                        candidates.append(pt)
                    # Do not add generic fallbacks to avoid mismatched model types
            except Exception as e:
                # Directory traversal problems are non-fatal: keep whatever was
                # collected so far, but leave a trace instead of failing silently.
                logger.debug(f"Skipping legacy model scan of {adir}: {e}")

        if candidates:
            # Pick the most recently modified candidate
            try:
                best = max(candidates, key=lambda p: p.stat().st_mtime)
                logger.debug(f"Found legacy model file in project models dir: {best}")
                return best
            except Exception:
                # If stat fails, just return the first one deterministically
                candidates.sort()
                logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
                return candidates[0]

        return None

    def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
        """Create metadata for legacy model files using only actual file information.

        Size and creation time come from the file's ``stat`` result (mtime is
        used as the creation time). Performance is unknown, so the score is 0.0
        and all metric fields stay ``None``. If ``stat`` fails, a minimal record
        stamped with the current time is returned instead.
        """
        try:
            file_stat = file_path.stat()
            file_size_mb = file_stat.st_size / (1024 * 1024)
            created_time = datetime.fromtimestamp(file_stat.st_mtime)

            # NO SYNTHETIC DATA - use only actual file information
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=created_time,
                file_size_mb=file_size_mb,
                performance_score=0.0  # Unknown performance - use 0, not synthetic values
            )
        except Exception as e:
            logger.error(f"Error creating legacy metadata for {model_name}: {e}")
            # Return a basic metadata with minimal info - NO SYNTHETIC VALUES
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=datetime.now(),
                file_size_mb=0.0,
                performance_score=0.0  # Unknown - use 0, not synthetic
            )


# Lazily created process-wide manager used by the module-level helpers below.
_checkpoint_manager = None


def get_checkpoint_manager() -> CheckpointManager:
    """Return the shared CheckpointManager, creating it with defaults on first use."""
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Save a checkpoint via the shared manager; see CheckpointManager.save_checkpoint."""
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )


def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Load the best checkpoint via the shared manager; see CheckpointManager.load_best_checkpoint."""
    return get_checkpoint_manager().load_best_checkpoint(model_name)