#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""

import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict

import torch
import random

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

logger = logging.getLogger(__name__)


@dataclass
class CheckpointMetadata:
    """Serializable metadata describing a single saved model checkpoint."""
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)
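
# Illustrative round trip (hypothetical values, not from the original module): to_dict()
# serializes created_at to an ISO-8601 string so the metadata survives json.dump/json.load,
# and from_dict() parses it back into a datetime:
#
#     meta = CheckpointMetadata(checkpoint_id="demo_20240101_120000", model_name="demo",
#                               model_type="dqn", file_path="demo.pt",
#                               created_at=datetime.now(), file_size_mb=1.0,
#                               performance_score=42.0)
#     restored = CheckpointMetadata.from_dict(json.loads(json.dumps(meta.to_dict())))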


class CheckpointManager:
    """Manages rotating model checkpoints with performance-based retention and optional W&B artifact upload."""

    def __init__(self, base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a checkpoint for the model if its performance warrants keeping it."""
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"

            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)
            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            success = self._save_model_file(model, checkpoint_path, model_type)
            if not success:
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)

            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Return (file_path, metadata) for the best checkpoint, falling back to legacy saved models."""
        try:
            # First, try the standard checkpoint system
            if model_name in self.checkpoints and self.checkpoints[model_name]:
                # Filter out checkpoints with non-existent files
                valid_checkpoints = [
                    cp for cp in self.checkpoints[model_name]
                    if Path(cp.file_path).exists()
                ]

                if valid_checkpoints:
                    best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
                    logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
                    return best_checkpoint.file_path, best_checkpoint
                else:
                    # Clean up invalid metadata entries
                    invalid_count = len(self.checkpoints[model_name])
                    logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
                    self.checkpoints[model_name] = []
                    self._save_metadata()

            # Fallback: Look for existing saved models in the legacy format
            logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
            legacy_model_path = self._find_legacy_model(model_name)
            if legacy_model_path:
                # Create checkpoint metadata for the legacy model using actual file data
                legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
                logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
                return str(legacy_model_path), legacy_metadata

            logger.warning(f"No checkpoints or legacy models found for: {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate performance score with improved sensitivity for training models"""
        score = 0.0

        # Prioritize loss reduction for active training models
        if 'loss' in metrics:
            # Invert loss so lower loss = higher score, with better scaling
            loss_value = metrics['loss']
            if loss_value > 0:
                score += max(0, 100 / (1 + loss_value))  # More sensitive to loss changes
            else:
                score += 100  # Perfect loss

        # Add other metrics with appropriate weights
        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 50  # Reduced weight to balance with loss
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 50
        if 'val_loss' in metrics:
            val_loss = metrics['val_loss']
            if val_loss > 0:
                score += max(0, 50 / (1 + val_loss))
        if 'reward' in metrics:
            score += metrics['reward'] * 10
        if 'pnl' in metrics:
            score += metrics['pnl'] * 5
        if 'training_samples' in metrics:
            # Bonus for processing more training samples
            score += min(10, metrics['training_samples'] / 10)

        # Return actual calculated score - NO SYNTHETIC MINIMUM
        return score
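
    # Illustrative worked example (hypothetical metric values, not from the original code):
    # for metrics = {'loss': 0.5, 'accuracy': 0.8, 'training_samples': 200} the score is
    #   100 / (1 + 0.5)     = 66.67   (loss term)
    # + 0.8 * 50            = 40.00   (accuracy term)
    # + min(10, 200 / 10)   = 10.00   (training-samples bonus, capped at 10)
    # = 116.67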

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Improved checkpoint saving logic with more frequent saves during training"""
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True  # Always save first checkpoint

        # Allow more checkpoints during active training
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        # Get current best and worst scores
        scores = [cp.performance_score for cp in self.checkpoints[model_name]]
        best_score = max(scores)
        worst_score = min(scores)

        # Save if better than worst (more frequent saves)
        if performance_score > worst_score:
            return True

        # For high-performing models (score > 100), be more sensitive to small improvements
        if best_score > 100:
            # Save if within 0.1% of best score (very sensitive for converged models)
            if performance_score >= best_score * 0.999:
                return True
        else:
            # Also save if we're within 10% of best score (capture near-optimal models)
            if performance_score >= best_score * 0.9:
                return True

        # Save more frequently during active training (every 5th attempt instead of 10th)
        if random.random() < 0.2:  # 20% chance to save anyway
            logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
            return True

        return False

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        checkpoint_list = self.checkpoints[model_name]

        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)

        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        try:
            if not self.enable_wandb or wandb.run is None:
                return None

            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)

            return artifact_name

        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")

        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        try:
            data = {}
            for model_name, checkpoint_list in self.checkpoints.items():
                data[model_name] = [cp.to_dict() for cp in checkpoint_list]

            with open(self.metadata_file, 'w') as f:
                json.dump(data, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")
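
    # For reference, _save_metadata() writes checkpoint_metadata.json keyed by model name,
    # e.g. (hypothetical values):
    # {
    #   "dqn_agent": [
    #     {"checkpoint_id": "dqn_agent_20240101_120000", "performance_score": 87.3,
    #      "file_path": "NN/models/saved/dqn_agent/dqn_agent_20240101_120000.pt", ...}
    #   ]
    # }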

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }

            stats['total_size_mb'] += model_size

        return stats

    def _find_legacy_model(self, model_name: str) -> Optional[Path]:
        """Find legacy saved models based on model name patterns"""
        base_dir = Path(self.base_dir)

        # Define model name mappings and patterns for legacy files
        legacy_patterns = {
            'dqn_agent': [
                'dqn_agent_best_policy.pt',
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt',
                'dqn_agent_final_policy.pt'
            ],
            'enhanced_cnn': [
                'cnn_model_best.pt',
                'optimized_short_term_model_best.pt',
                'optimized_short_term_model_realtime_best.pt',
                'optimized_short_term_model_ticks_best.pt'
            ],
            'extrema_trainer': [
                'supervised_model_best.pt'
            ],
            'cob_rl': [
                'best_rl_model.pth_policy.pt',
                'rl_agent_best_policy.pt'
            ],
            'decision': [
                # Decision models might be in subdirectories, but let's check main dir too
                'decision_best.pt',
                'decision_model_best.pt',
                # Check for transformer models which might be used as decision models
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt'
            ]
        }

        # Get patterns for this model name
        patterns = legacy_patterns.get(model_name, [])

        # Also try generic patterns based on model name
        patterns.extend([
            f'{model_name}_best.pt',
            f'{model_name}_best_policy.pt',
            f'{model_name}_final.pt',
            f'{model_name}_final_policy.pt'
        ])

        # Search for the model files
        for pattern in patterns:
            candidate_path = base_dir / pattern
            if candidate_path.exists():
                logger.debug(f"Found legacy model file: {candidate_path}")
                return candidate_path

        # Also check subdirectories
        for subdir in base_dir.iterdir():
            if subdir.is_dir() and subdir.name == model_name:
                for pattern in patterns:
                    candidate_path = subdir / pattern
                    if candidate_path.exists():
                        logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                        return candidate_path

        return None
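
    # Example of the lookup order (names taken from the patterns above): for
    # model_name="dqn_agent" the candidates checked under self.base_dir are
    # dqn_agent_best_policy.pt, enhanced_dqn_best_policy.pt,
    # improved_dqn_agent_best_policy.pt, dqn_agent_final_policy.pt, then the generic
    # dqn_agent_best.pt / dqn_agent_best_policy.pt / dqn_agent_final.pt /
    # dqn_agent_final_policy.pt, and finally the same names inside a "dqn_agent/" subdirectory.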

    def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
        """Create metadata for legacy model files using only actual file information"""
        try:
            file_size_mb = file_path.stat().st_size / (1024 * 1024)
            created_time = datetime.fromtimestamp(file_path.stat().st_mtime)

            # NO SYNTHETIC DATA - use only actual file information
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=created_time,
                file_size_mb=file_size_mb,
                performance_score=0.0,  # Unknown performance - use 0, not synthetic values
                accuracy=None,
                loss=None,
                val_accuracy=None,
                val_loss=None,
                reward=None,
                pnl=None,
                epoch=None,
                training_time_hours=None,
                total_parameters=None,
                wandb_run_id=None,
                wandb_artifact_name=None
            )
        except Exception as e:
            logger.error(f"Error creating legacy metadata for {model_name}: {e}")
            # Return a basic metadata with minimal info - NO SYNTHETIC VALUES
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=datetime.now(),
                file_size_mb=0.0,
                performance_score=0.0  # Unknown - use 0, not synthetic
            )


_checkpoint_manager = None

def get_checkpoint_manager() -> CheckpointManager:
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager

def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )

def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    return get_checkpoint_manager().load_best_checkpoint(model_name)
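

# Illustrative usage sketch (not part of the original module): exercises the manager end
# to end with a throwaway model and a temporary directory, so nothing is written into the
# real NN/models/saved tree. The model, metric values, and names below are hypothetical.
if __name__ == "__main__":
    import tempfile

    logging.basicConfig(level=logging.INFO)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Direct construction instead of the module-level singleton; W&B uploads disabled.
        manager = CheckpointManager(base_checkpoint_dir=tmp_dir, enable_wandb=False)
        demo_model = torch.nn.Linear(4, 2)  # stand-in for a real trading model

        saved = manager.save_checkpoint(
            model=demo_model,
            model_name="demo_model",
            model_type="linear",
            performance_metrics={"loss": 0.42, "accuracy": 0.81},
            training_metadata={"epoch": 1},
            force_save=True,
        )
        if saved:
            print(f"Saved checkpoint: {saved.checkpoint_id} (score={saved.performance_score:.2f})")

        best = manager.load_best_checkpoint("demo_model")
        if best:
            path, meta = best
            print(f"Best checkpoint on disk: {path} (score={meta.performance_score:.2f})")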