#!/usr/bin/env python3 """ Checkpoint Management System for W&B Training """ import os import json import logging from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass, asdict from collections import defaultdict import torch try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False logger = logging.getLogger(__name__) @dataclass class CheckpointMetadata: checkpoint_id: str model_name: str model_type: str file_path: str created_at: datetime file_size_mb: float performance_score: float accuracy: Optional[float] = None loss: Optional[float] = None val_accuracy: Optional[float] = None val_loss: Optional[float] = None reward: Optional[float] = None pnl: Optional[float] = None epoch: Optional[int] = None training_time_hours: Optional[float] = None total_parameters: Optional[int] = None wandb_run_id: Optional[str] = None wandb_artifact_name: Optional[str] = None def to_dict(self) -> Dict[str, Any]: data = asdict(self) data['created_at'] = self.created_at.isoformat() return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': data['created_at'] = datetime.fromisoformat(data['created_at']) return cls(**data) class CheckpointManager: def __init__(self, base_checkpoint_dir: str = "NN/models/saved", max_checkpoints_per_model: int = 5, metadata_file: str = "checkpoint_metadata.json", enable_wandb: bool = True): self.base_dir = Path(base_checkpoint_dir) self.base_dir.mkdir(parents=True, exist_ok=True) self.max_checkpoints = max_checkpoints_per_model self.metadata_file = self.base_dir / metadata_file self.enable_wandb = enable_wandb and WANDB_AVAILABLE self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list) self._load_metadata() logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}") def save_checkpoint(self, model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, force_save: bool = False) -> Optional[CheckpointMetadata]: try: timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') checkpoint_id = f"{model_name}_{timestamp}" model_dir = self.base_dir / model_name model_dir.mkdir(exist_ok=True) checkpoint_path = model_dir / f"{checkpoint_id}.pt" performance_score = self._calculate_performance_score(performance_metrics) if not force_save and not self._should_save_checkpoint(model_name, performance_score): logger.info(f"Skipping checkpoint save for {model_name} - performance not improved") return None success = self._save_model_file(model, checkpoint_path, model_type) if not success: return None file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024) metadata = CheckpointMetadata( checkpoint_id=checkpoint_id, model_name=model_name, model_type=model_type, file_path=str(checkpoint_path), created_at=datetime.now(), file_size_mb=file_size_mb, performance_score=performance_score, accuracy=performance_metrics.get('accuracy'), loss=performance_metrics.get('loss'), val_accuracy=performance_metrics.get('val_accuracy'), val_loss=performance_metrics.get('val_loss'), reward=performance_metrics.get('reward'), pnl=performance_metrics.get('pnl'), epoch=training_metadata.get('epoch') if training_metadata else None, training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None, total_parameters=training_metadata.get('total_parameters') if training_metadata else None ) if self.enable_wandb and wandb.run is not None: artifact_name = self._upload_to_wandb(checkpoint_path, metadata) metadata.wandb_run_id = wandb.run.id metadata.wandb_artifact_name = artifact_name self.checkpoints[model_name].append(metadata) self._rotate_checkpoints(model_name) self._save_metadata() logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})") return metadata except Exception as e: logger.error(f"Error saving checkpoint for {model_name}: {e}") return None def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: try: if model_name not in self.checkpoints or not self.checkpoints[model_name]: logger.warning(f"No checkpoints found for model: {model_name}") return None best_checkpoint = max(self.checkpoints[model_name], key=lambda x: x.performance_score) if not Path(best_checkpoint.file_path).exists(): logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}") return None logger.info(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}") return best_checkpoint.file_path, best_checkpoint except Exception as e: logger.error(f"Error loading best checkpoint for {model_name}: {e}") return None def _calculate_performance_score(self, metrics: Dict[str, float]) -> float: score = 0.0 if 'accuracy' in metrics: score += metrics['accuracy'] * 100 if 'val_accuracy' in metrics: score += metrics['val_accuracy'] * 100 if 'loss' in metrics: score += max(0, 10 - metrics['loss']) if 'val_loss' in metrics: score += max(0, 10 - metrics['val_loss']) if 'reward' in metrics: score += metrics['reward'] if 'pnl' in metrics: score += metrics['pnl'] if score == 0.0 and metrics: first_metric = next(iter(metrics.values())) score = first_metric if first_metric > 0 else 0.1 return max(score, 0.1) def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool: if model_name not in self.checkpoints or not self.checkpoints[model_name]: return True if len(self.checkpoints[model_name]) < self.max_checkpoints: return True worst_score = min(cp.performance_score for cp in self.checkpoints[model_name]) return performance_score > worst_score def _save_model_file(self, model, file_path: Path, model_type: str) -> bool: try: if hasattr(model, 'state_dict'): torch.save({ 'model_state_dict': model.state_dict(), 'model_type': model_type, 'saved_at': datetime.now().isoformat() }, file_path) else: torch.save(model, file_path) return True except Exception as e: logger.error(f"Error saving model file {file_path}: {e}") return False def _rotate_checkpoints(self, model_name: str): checkpoint_list = self.checkpoints[model_name] if len(checkpoint_list) <= self.max_checkpoints: return checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True) to_remove = checkpoint_list[self.max_checkpoints:] self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints] for checkpoint in to_remove: try: file_path = Path(checkpoint.file_path) if file_path.exists(): file_path.unlink() logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}") except Exception as e: logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}") def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]: try: if not self.enable_wandb or wandb.run is None: return None artifact_name = f"{metadata.model_name}_checkpoint" artifact = wandb.Artifact(artifact_name, type="model") artifact.add_file(str(file_path)) wandb.log_artifact(artifact) return artifact_name except Exception as e: logger.error(f"Error uploading to W&B: {e}") return None def _load_metadata(self): try: if self.metadata_file.exists(): with open(self.metadata_file, 'r') as f: data = json.load(f) for model_name, checkpoint_list in data.items(): self.checkpoints[model_name] = [ CheckpointMetadata.from_dict(cp_data) for cp_data in checkpoint_list ] logger.info(f"Loaded metadata for {len(self.checkpoints)} models") except Exception as e: logger.error(f"Error loading checkpoint metadata: {e}") def _save_metadata(self): try: data = {} for model_name, checkpoint_list in self.checkpoints.items(): data[model_name] = [cp.to_dict() for cp in checkpoint_list] with open(self.metadata_file, 'w') as f: json.dump(data, f, indent=2) except Exception as e: logger.error(f"Error saving checkpoint metadata: {e}") def get_checkpoint_stats(self): """Get statistics about managed checkpoints""" stats = { 'total_models': len(self.checkpoints), 'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()), 'total_size_mb': 0.0, 'models': {} } for model_name, checkpoint_list in self.checkpoints.items(): if not checkpoint_list: continue model_size = sum(cp.file_size_mb for cp in checkpoint_list) best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score) stats['models'][model_name] = { 'checkpoint_count': len(checkpoint_list), 'total_size_mb': model_size, 'best_performance': best_checkpoint.performance_score, 'best_checkpoint_id': best_checkpoint.checkpoint_id, 'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id } stats['total_size_mb'] += model_size return stats _checkpoint_manager = None def get_checkpoint_manager() -> CheckpointManager: global _checkpoint_manager if _checkpoint_manager is None: _checkpoint_manager = CheckpointManager() return _checkpoint_manager def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, force_save: bool = False) -> Optional[CheckpointMetadata]: return get_checkpoint_manager().save_checkpoint( model, model_name, model_type, performance_metrics, training_metadata, force_save ) def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: return get_checkpoint_manager().load_best_checkpoint(model_name)