#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""

import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict, fields
from collections import defaultdict
import torch
import random

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

logger = logging.getLogger(__name__)


@dataclass
class CheckpointMetadata:
    """Metadata describing one saved checkpoint file.

    Persisted to JSON through ``to_dict`` / ``from_dict``.
    """

    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    # Higher is better; produced by CheckpointManager._calculate_performance_score.
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (``created_at`` as an ISO-8601 string)."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Build an instance from a dict in the shape produced by ``to_dict``.

        The input dict is not modified. Keys that are not fields of this
        dataclass are ignored, so metadata written by an older/newer version
        of this module still loads.

        Raises:
            KeyError: if ``created_at`` is missing.
            TypeError: if another required field is missing.
            ValueError: if ``created_at`` is not a valid ISO-8601 string.
        """
        known = {fld.name for fld in fields(cls)}
        kwargs = {key: value for key, value in data.items() if key in known}
        kwargs['created_at'] = datetime.fromisoformat(kwargs['created_at'])
        return cls(**kwargs)


class CheckpointManager:
    """Save, rank, rotate and look up model checkpoints on local disk.

    Checkpoint files live under ``<base_checkpoint_dir>/<model_name>/`` and all
    metadata is persisted to one JSON file inside ``base_checkpoint_dir``. At
    most ``max_checkpoints_per_model`` checkpoints (the highest-scoring ones)
    are kept per model. When W&B is installed, enabled and a run is active,
    each saved file is also logged as a W&B ``model`` artifact.
    """

    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
        # model_name -> checkpoints currently retained for that model
        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()
        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save ``model`` as a new checkpoint if it is worth keeping.

        Args:
            model: object to save; if it has ``state_dict`` only that is saved.
            model_name: groups checkpoints (sub-directory and rotation unit).
            model_type: free-form type label stored with the checkpoint.
            performance_metrics: metrics scored by ``_calculate_performance_score``
                and copied into the metadata (accuracy, loss, val_*, reward, pnl).
            training_metadata: optional ``epoch``, ``training_time_hours`` and
                ``total_parameters`` values copied into the metadata.
            force_save: bypass the ``_should_save_checkpoint`` check.

        Returns:
            Metadata of the retained checkpoint, or ``None`` if the save was
            skipped, failed, or the new checkpoint ranked below the retained
            ones and was rotated out (its file deleted) right away.
        """
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)
            checkpoint_id = self._unique_checkpoint_id(model_name, model_dir, timestamp)
            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            if not self._save_model_file(model, checkpoint_path, model_type):
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
            training_metadata = training_metadata or {}
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch'),
                training_time_hours=training_metadata.get('training_time_hours'),
                total_parameters=training_metadata.get('total_parameters'),
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            # Rotation keeps only the top ``max_checkpoints`` by score; a new
            # checkpoint that ranks last is deleted immediately, so do not hand
            # back metadata that points at a file which no longer exists.
            if not any(cp is metadata for cp in self.checkpoints[model_name]):
                logger.info(f"Checkpoint {checkpoint_id} (score: {performance_score:.4f}) "
                            f"was rotated out immediately - not retained")
                return None

            logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _unique_checkpoint_id(self, model_name: str, model_dir: Path, timestamp: str) -> str:
        """Return ``<model_name>_<timestamp>``, suffixed with ``_<n>`` when needed.

        The timestamp has one-second resolution, so repeated saves within the
        same second would otherwise share a file path; the suffix keeps both
        the ID and its ``.pt`` file distinct from existing checkpoints.
        """
        base_id = f"{model_name}_{timestamp}"
        # .get() avoids creating an empty defaultdict entry for a skipped save.
        existing_ids = {cp.checkpoint_id for cp in self.checkpoints.get(model_name, [])}
        checkpoint_id = base_id
        counter = 1
        while checkpoint_id in existing_ids or (model_dir / f"{checkpoint_id}.pt").exists():
            checkpoint_id = f"{base_id}_{counter}"
            counter += 1
        return checkpoint_id

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Return ``(file_path, metadata)`` of the best checkpoint still on disk.

        Checkpoints are tried in descending ``performance_score`` order; ones
        whose file is missing (e.g. deleted outside this manager) are skipped.
        Returns ``None`` if the model has no checkpoints or none of the files
        exist, or if an unexpected error occurs (it is logged).
        """
        try:
            checkpoint_list = self.checkpoints.get(model_name)
            if not checkpoint_list:
                logger.warning(f"No checkpoints found for model: {model_name}")
                return None

            ranked = sorted(checkpoint_list, key=lambda x: x.performance_score, reverse=True)
            for checkpoint in ranked:
                if Path(checkpoint.file_path).exists():
                    logger.debug(f"Loading best checkpoint for {model_name}: {checkpoint.checkpoint_id}")
                    return checkpoint.file_path, checkpoint
                # Debug level on purpose: this path can be hit on every call
                # and would otherwise spam the log.
                logger.debug(f"Checkpoint file not found, trying next best: {checkpoint.file_path}")
            return None
        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Combine training metrics into one score (higher is better, min 0.1).

        Contributions: ``loss`` -> up to 100 via ``100 / (1 + loss)`` (100 when
        loss <= 0); ``accuracy`` and ``val_accuracy`` -> x50 each; ``val_loss``
        -> up to 50 via ``50 / (1 + val_loss)`` (50 when val_loss <= 0);
        ``reward`` -> x10; ``pnl`` -> x5; ``training_samples`` -> /10, capped
        at 10. If nothing above contributes, the first metric value (clamped
        to [0.1, 10]) is used instead.
        """
        score = 0.0

        # Prioritize loss reduction for active training models
        if 'loss' in metrics:
            loss_value = metrics['loss']
            if loss_value > 0:
                score += max(0, 100 / (1 + loss_value))  # lower loss -> higher score
            else:
                score += 100  # Perfect loss

        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 50
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 50
        if 'val_loss' in metrics:
            val_loss = metrics['val_loss']
            if val_loss > 0:
                score += max(0, 50 / (1 + val_loss))
            else:
                # Mirror the 'loss' handling: a perfect validation loss earns the
                # full weight instead of dropping to 0 (which scored val_loss=0
                # below val_loss=0.001).
                score += 50
        if 'reward' in metrics:
            score += metrics['reward'] * 10
        if 'pnl' in metrics:
            score += metrics['pnl'] * 5
        if 'training_samples' in metrics:
            # Bonus for processing more training samples
            score += min(10, metrics['training_samples'] / 10)

        # Ensure minimum score for any training activity
        if score == 0.0 and metrics:
            first_metric = next(iter(metrics.values()))
            if first_metric > 0:
                score = max(0.1, min(10, first_metric))
            else:
                score = 0.1

        return max(score, 0.1)

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Decide whether a checkpoint with ``performance_score`` should be saved.

        Always saves while fewer than ``max_checkpoints`` exist or when the
        score beats the current worst. Otherwise saves when within 0.1% of the
        best (best > 100) or within 10% of the best (best <= 100), and finally
        with a 20% random chance. Note: at capacity, a checkpoint admitted by
        the last three rules ranks lowest and is removed by rotation.
        """
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True  # Always save first checkpoint

        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        scores = [cp.performance_score for cp in self.checkpoints[model_name]]
        best_score = max(scores)
        worst_score = min(scores)

        if performance_score > worst_score:
            return True

        if best_score > 100:
            # Converged models: accept scores within 0.1% of the best
            if performance_score >= best_score * 0.999:
                return True
        else:
            # Otherwise accept scores within 10% of the best
            if performance_score >= best_score * 0.9:
                return True

        if random.random() < 0.2:  # 20% chance to save anyway
            logger.info(f"Saving checkpoint for {model_name} - periodic save during active training")
            return True

        return False

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        """Write ``model`` to ``file_path`` with ``torch.save``; return success.

        Objects with ``state_dict`` are saved as a dict holding the state dict,
        ``model_type`` and a save timestamp; anything else is saved as-is.
        """
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        """Keep the ``max_checkpoints`` highest-scoring checkpoints and delete the rest.

        The retained list is left sorted by descending score; among equal
        scores, earlier entries are kept (stable sort).
        """
        checkpoint_list = self.checkpoints[model_name]
        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        """Log ``file_path`` as a W&B ``model`` artifact; return its name or ``None``."""
        try:
            if not self.enable_wandb or wandb.run is None:
                return None

            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)
            return artifact_name
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        """Populate ``self.checkpoints`` from the JSON metadata file, if present.

        Malformed entries are skipped with a warning rather than discarding
        every model's metadata; other errors are logged.
        """
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    loaded = []
                    for cp_data in checkpoint_list:
                        try:
                            loaded.append(CheckpointMetadata.from_dict(cp_data))
                        except (KeyError, TypeError, ValueError, AttributeError) as e:
                            logger.warning(f"Skipping malformed checkpoint metadata for {model_name}: {e}")
                    self.checkpoints[model_name] = loaded

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        """Persist ``self.checkpoints`` to the JSON metadata file.

        The JSON is written to a temporary sibling file and swapped in with
        ``os.replace`` so an interrupted write cannot leave a truncated file,
        which would make all recorded checkpoints unloadable.
        """
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoint_list]
                for model_name, checkpoint_list in self.checkpoints.items()
            }
            tmp_file = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
            with open(tmp_file, 'w') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_file, self.metadata_file)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }
            stats['total_size_mb'] += model_size

        return stats


# Lazily-created process-wide manager used by the module-level helpers below.
_checkpoint_manager = None


def get_checkpoint_manager() -> CheckpointManager:
    """Return the shared CheckpointManager, creating it with defaults on first use."""
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Shortcut for ``get_checkpoint_manager().save_checkpoint(...)``."""
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )


def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Shortcut for ``get_checkpoint_manager().load_best_checkpoint(model_name)``."""
    return get_checkpoint_manager().load_best_checkpoint(model_name)