Files
gogo2/utils/checkpoint_manager.py
2025-06-24 21:41:50 +03:00

307 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
logger = logging.getLogger(__name__)

# Weights & Biases is optional: CheckpointManager only enables W&B uploads
# when this flag is True (it is AND-ed with the `enable_wandb` argument).
try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True
@dataclass
class CheckpointMetadata:
    """Descriptive record for one saved checkpoint file.

    Instances are built by ``CheckpointManager.save_checkpoint`` and stored in
    the JSON index via :meth:`to_dict` / :meth:`from_dict`.
    """

    # Identifier; CheckpointManager names the file ``<checkpoint_id>.pt``.
    checkpoint_id: str
    # Logical model name; CheckpointManager also uses it as the subdirectory name.
    model_name: str
    # Free-form type label supplied by the caller.
    model_type: str
    # Path of the checkpoint file, stored as a string.
    file_path: str
    # When the checkpoint was recorded (CheckpointManager uses naive local datetime.now()).
    created_at: datetime
    # File size computed as st_size / (1024 * 1024).
    file_size_mb: float
    # Composite ranking score; higher is better (rotation keeps the highest).
    performance_score: float
    # Optional metrics copied from the caller's performance_metrics dict.
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    # Optional values copied from the caller's training_metadata dict.
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    # Populated only when the checkpoint was uploaded to an active W&B run.
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict with ``created_at`` as an ISO-8601 string."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Rebuild an instance from a dict produced by :meth:`to_dict`.

        The input dict is not modified. ``created_at`` may be an ISO-8601
        string (as written by :meth:`to_dict`) or already a ``datetime``.

        Raises:
            KeyError: if ``created_at`` is missing.
            TypeError: if required fields are missing or unknown keys are present.
            ValueError: if ``created_at`` is not a valid ISO-8601 string.
        """
        # Work on a copy: callers (e.g. JSON loaders) may reuse their dicts.
        payload = dict(data)
        created_at = payload['created_at']
        if not isinstance(created_at, datetime):
            payload['created_at'] = datetime.fromisoformat(created_at)
        return cls(**payload)
class CheckpointManager:
    """Save, rank, rotate and reload model checkpoints.

    On-disk layout produced by the methods below::

        <base_checkpoint_dir>/<model_name>/<checkpoint_id>.pt
        <base_checkpoint_dir>/<metadata_file>        # JSON index of all checkpoints

    Each checkpoint gets a composite ``performance_score`` (higher is better).
    At most ``max_checkpoints_per_model`` checkpoints are kept per model; when
    that limit is exceeded the lowest-scoring files are deleted. If W&B is
    enabled and a run is active, each saved file is also logged as a W&B
    ``model`` artifact.
    """

    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        """Create the base directory if needed and load any existing JSON index.

        Args:
            base_checkpoint_dir: Root directory for checkpoint files and the index.
            max_checkpoints_per_model: Number of checkpoints retained per model.
            metadata_file: File name of the JSON index, relative to the root directory.
            enable_wandb: Upload checkpoints to W&B; forced off when ``wandb``
                could not be imported.
        """
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
        self.checkpoints: Dict[str, List['CheckpointMetadata']] = defaultdict(list)
        self._load_metadata()
        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional['CheckpointMetadata']:
        """Write ``model`` to a new checkpoint file and record its metadata.

        Unless ``force_save`` is set, the save is skipped when the model already
        has ``max_checkpoints`` entries and the new score does not beat the
        worst of them.

        Args:
            model: Object with ``state_dict()`` (only the state dict is saved),
                or any other object, which is passed whole to ``torch.save``.
            model_name: Logical model name; also the subdirectory for its files.
            model_type: Free-form label stored in the file and the metadata.
            performance_metrics: Metrics scored by ``_calculate_performance_score``;
                the recognised keys are also copied into the metadata.
            training_metadata: Optional ``epoch``, ``training_time_hours`` and
                ``total_parameters`` values copied into the metadata.
            force_save: Save even if the score would not make the cut. The new
                checkpoint is then protected from the rotation that follows.

        Returns:
            The new ``CheckpointMetadata``, or None if skipped or on any error
            (errors are logged, not raised).
        """
        try:
            # One timestamp for both the id and created_at so they always agree.
            now = datetime.now()
            model_dir = self.base_dir / model_name
            model_dir.mkdir(parents=True, exist_ok=True)
            checkpoint_id = self._make_checkpoint_id(model_name, model_dir, now)
            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.info(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            if not self._save_model_file(model, checkpoint_path, model_type):
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
            extra = training_metadata or {}
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=now,
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=extra.get('epoch'),
                training_time_hours=extra.get('training_time_hours'),
                total_parameters=extra.get('total_parameters'),
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            # Never rotate out the checkpoint just written: with force_save it may
            # score below every retained one, and deleting it would leave the
            # returned metadata pointing at a missing file.
            self._rotate_checkpoints(model_name, protect_id=checkpoint_id)
            self._save_metadata()

            logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _make_checkpoint_id(self, model_name: str, model_dir: Path, when: datetime) -> str:
        """Return ``<model_name>_<YYYYmmdd_HHMMSS_ffffff>``, suffixed if already taken.

        Microsecond resolution plus a numeric suffix fallback ensures two saves
        in quick succession never share (and overwrite) one checkpoint file.
        """
        base_id = f"{model_name}_{when.strftime('%Y%m%d_%H%M%S_%f')}"
        taken = {cp.checkpoint_id for cp in self.checkpoints.get(model_name, ())}
        candidate, suffix = base_id, 1
        while candidate in taken or (model_dir / f"{candidate}.pt").exists():
            candidate = f"{base_id}_{suffix}"
            suffix += 1
        return candidate

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, 'CheckpointMetadata']]:
        """Return ``(file_path, metadata)`` for the highest-scoring checkpoint on disk.

        The model itself is not loaded; callers load ``file_path``. If the
        top-ranked file has gone missing, the next best existing file is
        returned instead. Returns None when no checkpoints are recorded, none of
        their files exist, or an error occurs (errors are logged, not raised).
        """
        try:
            candidates = self.checkpoints.get(model_name)
            if not candidates:
                logger.warning(f"No checkpoints found for model: {model_name}")
                return None
            # Stable sort: among equal scores the earliest-recorded entry wins,
            # matching max()'s tie-breaking.
            for checkpoint in sorted(candidates, key=lambda cp: cp.performance_score, reverse=True):
                if Path(checkpoint.file_path).exists():
                    logger.info(f"Loading best checkpoint for {model_name}: {checkpoint.checkpoint_id}")
                    return checkpoint.file_path, checkpoint
                logger.warning(f"Checkpoint file not found, trying next best: {checkpoint.file_path}")
            logger.error(f"No checkpoint files found on disk for model: {model_name}")
            return None
        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Combine metrics into one ranking score (higher is better).

        ``accuracy`` / ``val_accuracy`` contribute ``value * 100``; ``loss`` /
        ``val_loss`` contribute ``max(0, 10 - value)``; ``reward`` and ``pnl``
        are added as-is. If that total is exactly 0 and ``metrics`` is
        non-empty, the first metric's value is used instead (0.1 if it is not
        positive). The result is never below 0.1.
        """
        score = 0.0
        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 100
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 100
        if 'loss' in metrics:
            score += max(0, 10 - metrics['loss'])
        if 'val_loss' in metrics:
            score += max(0, 10 - metrics['val_loss'])
        if 'reward' in metrics:
            score += metrics['reward']
        if 'pnl' in metrics:
            score += metrics['pnl']
        if score == 0.0 and metrics:
            first_metric = next(iter(metrics.values()))
            score = first_metric if first_metric > 0 else 0.1
        return max(score, 0.1)

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """True if the model has free slots or ``performance_score`` beats its worst checkpoint."""
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True
        worst_score = min(cp.performance_score for cp in self.checkpoints[model_name])
        return performance_score > worst_score

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        """Serialise ``model`` to ``file_path`` with ``torch.save``.

        Objects exposing ``state_dict()`` are saved as a dict with keys
        ``model_state_dict``, ``model_type`` and ``saved_at``; anything else is
        saved whole. Returns True on success. On failure the error is logged,
        any partially written file is removed, and False is returned.
        """
        try:
            if hasattr(model, 'state_dict'):
                payload = {
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }
            else:
                payload = model
            torch.save(payload, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            # torch.save may already have created the file before failing;
            # don't leave a truncated checkpoint behind.
            try:
                Path(file_path).unlink(missing_ok=True)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial checkpoint {file_path}: {cleanup_error}")
            return False

    def _rotate_checkpoints(self, model_name: str, protect_id: Optional[str] = None):
        """Trim ``model_name`` to ``max_checkpoints`` entries, deleting the lowest-scoring files.

        Args:
            model_name: Model whose checkpoint list is trimmed.
            protect_id: Checkpoint id that is kept regardless of its score
                (``save_checkpoint`` passes the checkpoint it just wrote).
        """
        checkpoint_list = self.checkpoints[model_name]
        if len(checkpoint_list) <= self.max_checkpoints:
            return
        # Rank by (is protected, score), highest first. sorted() is stable even
        # with reverse=True, so equal scores keep their existing relative order.
        ranked = sorted(
            checkpoint_list,
            key=lambda cp: (cp.checkpoint_id == protect_id, cp.performance_score),
            reverse=True,
        )
        kept = ranked[:self.max_checkpoints]
        to_remove = ranked[self.max_checkpoints:]
        self.checkpoints[model_name] = kept
        # Ids generated with only second resolution could collide, so several
        # entries may share one file; never delete a file a kept entry uses.
        kept_paths = {cp.file_path for cp in kept}
        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if checkpoint.file_path not in kept_paths and file_path.exists():
                    file_path.unlink()
                logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: 'CheckpointMetadata') -> Optional[str]:
        """Log ``file_path`` as a W&B ``model`` artifact named ``<model_name>_checkpoint``.

        Returns the artifact name, or None if W&B is disabled, no run is
        active, or the upload fails (errors are logged, not raised).
        """
        try:
            if not self.enable_wandb or wandb.run is None:
                return None
            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)
            return artifact_name
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        """Populate ``self.checkpoints`` from the JSON index, if it exists.

        Errors are logged, not raised; entries loaded before the failure stay.
        """
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]
                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        """Persist ``self.checkpoints`` to the JSON index file.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        index with ``os.replace``, so a crash mid-write leaves the previous
        index intact instead of a truncated one. Errors are logged, not raised.
        """
        tmp_path = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoint_list]
                for model_name, checkpoint_list in self.checkpoints.items()
            }
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.metadata_file)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")
            # Best-effort cleanup of the temp file; the original error is already logged.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints.

        Returns a dict with ``total_models``, ``total_checkpoints``,
        ``total_size_mb`` and ``models``; ``models`` maps each model with at
        least one checkpoint to its count, size, best score/id and latest id.
        """
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }
        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue
            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }
            stats['total_size_mb'] += model_size
        return stats
# Lazily created CheckpointManager shared by this module's convenience functions.
_checkpoint_manager = None


def get_checkpoint_manager() -> CheckpointManager:
    """Return the shared CheckpointManager, constructing it with default settings on first call."""
    global _checkpoint_manager
    if _checkpoint_manager is not None:
        return _checkpoint_manager
    _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager
def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Module-level shortcut for ``get_checkpoint_manager().save_checkpoint(...)``.

    All arguments are forwarded unchanged to the shared manager; see
    ``CheckpointManager.save_checkpoint`` for their meaning and return value.
    """
    manager = get_checkpoint_manager()
    return manager.save_checkpoint(
        model,
        model_name,
        model_type,
        performance_metrics,
        training_metadata=training_metadata,
        force_save=force_save,
    )
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Module-level shortcut for ``get_checkpoint_manager().load_best_checkpoint(model_name)``."""
    manager = get_checkpoint_manager()
    return manager.load_best_checkpoint(model_name)