checkpoint manager
@@ -0,0 +1,3 @@
"""
Utils package for the multi-modal trading system
"""
@@ -1,466 +1,408 @@
#!/usr/bin/env python3
"""
Checkpoint Manager

Checkpoint management system for W&B training. This module provides
functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
"""

import os
import json
import glob
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import shutil
import torch
import random

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

logger = logging.getLogger(__name__)

@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)


# Global checkpoint manager instance
_checkpoint_manager_instance = None


def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
    """
    Get the global checkpoint manager instance

    Args:
        checkpoint_dir: Directory to store checkpoints
        max_checkpoints: Maximum number of checkpoints to keep
        metric_name: Metric to use for ranking checkpoints

    Returns:
        CheckpointManager: Global checkpoint manager instance
    """
    global _checkpoint_manager_instance

    if _checkpoint_manager_instance is None:
        _checkpoint_manager_instance = CheckpointManager(
            checkpoint_dir=checkpoint_dir,
            max_checkpoints=max_checkpoints,
            metric_name=metric_name
        )

    return _checkpoint_manager_instance
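
# Illustrative usage sketch (not part of the original module; names follow the code above):
#
#   manager = get_checkpoint_manager("models/checkpoints", max_checkpoints=10, metric_name="accuracy")
#   assert manager is get_checkpoint_manager()  # same singleton instance on every call
#
# CheckpointMetadata round-trips through to_dict()/from_dict(); created_at is serialized
# as an ISO-8601 string so records can be stored as plain JSON.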


def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
    """
    Save a checkpoint with metadata

    Args:
        model: The model to save
        model_name: Name of the model
        model_type: Type of the model ('cnn', 'rl', etc.)
        performance_metrics: Performance metrics
        training_metadata: Additional training metadata
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Any: Checkpoint metadata
    """
    try:
        # Create checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Create checkpoint path
        model_dir = os.path.join(checkpoint_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")

        # Save model
        if hasattr(model, 'save'):
            # Use model's save method if available
            model.save(checkpoint_path)
        else:
            # Otherwise, save state_dict
            torch_path = f"{checkpoint_path}.pt"
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                'model_name': model_name,
                'model_type': model_type,
                'timestamp': timestamp
            }, torch_path)

        # Create metadata
        checkpoint_metadata = {
            'model_name': model_name,
            'model_type': model_type,
            'timestamp': timestamp,
            'performance_metrics': performance_metrics,
            'training_metadata': training_metadata or {},
            'checkpoint_id': f"{model_name}_{timestamp}"
        }

        # Add performance score for sorting
        primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
        checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
        checkpoint_metadata['created_at'] = timestamp

        # Save metadata
        with open(f"{checkpoint_path}_metadata.json", 'w') as f:
            json.dump(checkpoint_metadata, f, indent=2)

        # Get checkpoint manager and clean up old checkpoints
        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
        checkpoint_manager._cleanup_checkpoints(model_name)

        # Return metadata as an object
        class CheckpointMetadata:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)

        return CheckpointMetadata(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error saving checkpoint: {e}")
        return None
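
# Resulting on-disk layout for, e.g., save_checkpoint(net, "cnn", "cnn", {"accuracy": 0.8})
# (illustrative timestamp; paths follow the code above):
#
#   models/checkpoints/cnn/cnn_20240101_120000.pt              - weights (torch.save payload when the model has no custom save() method)
#   models/checkpoints/cnn/cnn_20240101_120000_metadata.json   - metrics and training metadata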


def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
    """
    Load the best checkpoint based on performance metrics

    Args:
        model_name: Name of the model
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
    """
    try:
        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
        checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)

        if not checkpoint_path:
            return None

        # Convert metadata to object
        class CheckpointMetadata:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)

                # Add performance score if not present
                if not hasattr(self, 'performance_score'):
                    metrics = getattr(self, 'metrics', {})
                    primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
                    self.performance_score = metrics.get(primary_metric, 0.0)

                # Add created_at if not present
                if not hasattr(self, 'created_at'):
                    self.created_at = getattr(self, 'timestamp', 'unknown')

        return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
        return None
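
# Illustrative end-to-end sketch (assumes an nn.Module named `net`; names are examples,
# not part of this module):
#
#   save_checkpoint(net, model_name="cnn", model_type="cnn",
#                   performance_metrics={"accuracy": 0.83, "loss": 0.41})
#   result = load_best_checkpoint("cnn")
#   if result:
#       path, meta = result
#       state = torch.load(path)  # dict written by torch.save() above
#       net.load_state_dict(state["model_state_dict"])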


class CheckpointManager:
    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"

            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)

            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            success = self._save_model_file(model, checkpoint_path, model_type)
            if not success:
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)

            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        try:
            # First, try the standard checkpoint system
            if model_name in self.checkpoints and self.checkpoints[model_name]:
                # Filter out checkpoints with non-existent files
                valid_checkpoints = [
                    cp for cp in self.checkpoints[model_name]
                    if Path(cp.file_path).exists()
                ]

                if valid_checkpoints:
                    best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
                    logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
                    return best_checkpoint.file_path, best_checkpoint
                else:
                    # Clean up invalid metadata entries
                    invalid_count = len(self.checkpoints[model_name])
                    logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
                    self.checkpoints[model_name] = []
                    self._save_metadata()

            # Fallback: Look for existing saved models in the legacy format
            logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
            legacy_model_path = self._find_legacy_model(model_name)

            if legacy_model_path:
                # Create checkpoint metadata for the legacy model using actual file data
                legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
                logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
                return str(legacy_model_path), legacy_metadata

            logger.warning(f"No checkpoints or legacy models found for: {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate performance score with improved sensitivity for training models"""
        score = 0.0

        # Prioritize loss reduction for active training models
        if 'loss' in metrics:
            # Invert loss so lower loss = higher score, with better scaling
            loss_value = metrics['loss']
            if loss_value > 0:
                score += max(0, 100 / (1 + loss_value))  # More sensitive to loss changes
            else:
                score += 100  # Perfect loss

        # Add other metrics with appropriate weights
        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 50  # Reduced weight to balance with loss
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 50
        if 'val_loss' in metrics:
            val_loss = metrics['val_loss']
            if val_loss > 0:
                score += max(0, 50 / (1 + val_loss))
        if 'reward' in metrics:
            score += metrics['reward'] * 10
        if 'pnl' in metrics:
            score += metrics['pnl'] * 5
        if 'training_samples' in metrics:
            # Bonus for processing more training samples
            score += min(10, metrics['training_samples'] / 10)

        # Return actual calculated score - NO SYNTHETIC MINIMUM
        return score
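
    # Worked example (illustrative): for metrics = {'loss': 0.5, 'accuracy': 0.8}
    #   loss term:     100 / (1 + 0.5) = 66.67
    #   accuracy term: 0.8 * 50        = 40.00
    #   total score    ~ 106.67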

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Improved checkpoint saving logic with more frequent saves during training"""
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True  # Always save first checkpoint

        # Allow more checkpoints during active training
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        # Get current best and worst scores
        scores = [cp.performance_score for cp in self.checkpoints[model_name]]
        best_score = max(scores)
        worst_score = min(scores)

        # Save if better than worst (more frequent saves)
        if performance_score > worst_score:
            return True

        # For high-performing models (score > 100), be more sensitive to small improvements
        if best_score > 100:
            # Save if within 0.1% of best score (very sensitive for converged models)
            if performance_score >= best_score * 0.999:
                return True
        else:
            # Also save if we're within 10% of best score (capture near-optimal models)
            if performance_score >= best_score * 0.9:
                return True

        # Save more frequently during active training (every 5th attempt instead of 10th)
        if random.random() < 0.2:  # 20% chance to save anyway
            logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
            return True

        return False
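
    # Illustrative decision trace: with max_checkpoints=5 and existing scores
    # [98.0, 101.5, 106.7, 110.2, 115.9], a new score of 103.0 is saved because it beats
    # the current worst (98.0); a new score of 97.0 falls through to the 20% random
    # "periodic save", since best > 100 and 97.0 < 0.999 * 115.9.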

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        checkpoint_list = self.checkpoints[model_name]

        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)

        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        try:
            if not self.enable_wandb or wandb.run is None:
                return None

            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)

            return artifact_name
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        try:
            data = {}
            for model_name, checkpoint_list in self.checkpoints.items():
                data[model_name] = [cp.to_dict() for cp in checkpoint_list]

            with open(self.metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }

            stats['total_size_mb'] += model_size

        return stats

    def _find_legacy_model(self, model_name: str) -> Optional[Path]:
        """Find legacy saved models based on model name patterns"""
        base_dir = Path(self.base_dir)

        # Define model name mappings and patterns for legacy files
        legacy_patterns = {
            'dqn_agent': [
                'dqn_agent_best_policy.pt',
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt',
                'dqn_agent_final_policy.pt'
            ],
            'enhanced_cnn': [
                'cnn_model_best.pt',
                'optimized_short_term_model_best.pt',
                'optimized_short_term_model_realtime_best.pt',
                'optimized_short_term_model_ticks_best.pt'
            ],
            'extrema_trainer': [
                'supervised_model_best.pt'
            ],
            'cob_rl': [
                'best_rl_model.pth_policy.pt',
                'rl_agent_best_policy.pt'
            ],
            'decision': [
                # Decision models might be in subdirectories, but let's check main dir too
                'decision_best.pt',
                'decision_model_best.pt',
                # Check for transformer models which might be used as decision models
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt'
            ]
        }

        # Get patterns for this model name
        patterns = legacy_patterns.get(model_name, [])

        # Also try generic patterns based on model name
        patterns.extend([
            f'{model_name}_best.pt',
            f'{model_name}_best_policy.pt',
            f'{model_name}_final.pt',
            f'{model_name}_final_policy.pt'
        ])

        # Search for the model files
        for pattern in patterns:
            candidate_path = base_dir / pattern
            if candidate_path.exists():
                logger.debug(f"Found legacy model file: {candidate_path}")
                return candidate_path

        # Also check subdirectories
        for subdir in base_dir.iterdir():
            if subdir.is_dir() and subdir.name == model_name:
                for pattern in patterns:
                    candidate_path = subdir / pattern
                    if candidate_path.exists():
                        logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                        return candidate_path

        return None
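
    # Example (illustrative): for model_name='dqn_agent' the lookup tries, in order,
    # dqn_agent_best_policy.pt, enhanced_dqn_best_policy.pt, improved_dqn_agent_best_policy.pt,
    # dqn_agent_final_policy.pt, then the generic dqn_agent_best*.pt / dqn_agent_final*.pt
    # patterns - first directly under base_dir, then under base_dir/dqn_agent/.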

    def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
        """Create metadata for legacy model files using only actual file information"""
        try:
            file_size_mb = file_path.stat().st_size / (1024 * 1024)
            created_time = datetime.fromtimestamp(file_path.stat().st_mtime)

            # NO SYNTHETIC DATA - use only actual file information
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=created_time,
                file_size_mb=file_size_mb,
                performance_score=0.0,  # Unknown performance - use 0, not synthetic values
                accuracy=None,
                loss=None,
                val_accuracy=None,
                val_loss=None,
                reward=None,
                pnl=None,
                epoch=None,
                training_time_hours=None,
                total_parameters=None,
                wandb_run_id=None,
                wandb_artifact_name=None
            )
        except Exception as e:
            logger.error(f"Error creating legacy metadata for {model_name}: {e}")
            # Return a basic metadata with minimal info - NO SYNTHETIC VALUES
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=datetime.now(),
                file_size_mb=0.0,
                performance_score=0.0  # Unknown - use 0, not synthetic
            )


_checkpoint_manager = None


def get_checkpoint_manager() -> CheckpointManager:
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )


def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    return get_checkpoint_manager().load_best_checkpoint(model_name)


class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints
    """

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints
            max_checkpoints: Maximum number of checkpoints to keep
            metric_name: Metric to use for ranking checkpoints
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a checkpoint with metadata

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Path to the saved checkpoint
        """
        try:
            # Create timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create checkpoint directory
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Create checkpoint path
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")

            # Copy model file to checkpoint path
            shutil.copy2(model_path, f"{checkpoint_path}.pt")

            # Create metadata
            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }

            # Save metadata
            with open(f"{checkpoint_path}_metadata.json", 'w') as f:
                json.dump(checkpoint_metadata, f, indent=2)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            # Clean up old checkpoints
            self._cleanup_checkpoints(model_name)

            return checkpoint_path

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Load the best checkpoint based on performance metrics

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files:
                logger.info(f"No checkpoints found for {model_name}")
                return "", {}

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]

                    # Check if model file exists
                    if not os.path.exists(f"{checkpoint_path}.pt"):
                        logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                        continue

                    checkpoints.append((checkpoint_path, metadata))

                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            if not checkpoints:
                logger.info(f"No valid checkpoints found for {model_name}")
                return "", {}

            # Sort by metric (highest first)
            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)

            # Return best checkpoint
            best_checkpoint_path = checkpoints[0][0]
            best_checkpoint_metadata = checkpoints[0][1]

            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")

            return best_checkpoint_path, best_checkpoint_metadata

        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return "", {}

    def _cleanup_checkpoints(self, model_name: str) -> int:
        """
        Clean up old or underperforming checkpoints

        Args:
            model_name: Name of the model

        Returns:
            int: Number of checkpoints deleted
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files or len(metadata_files) <= self.max_checkpoints:
                return 0

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]

                    checkpoints.append((checkpoint_path, metadata))

                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            # Sort by metric (highest first)
            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)

            # Keep only the best checkpoints
            checkpoints_to_delete = checkpoints[self.max_checkpoints:]

            # Delete checkpoints
            deleted_count = 0
            for checkpoint_path, _ in checkpoints_to_delete:
                try:
                    # Delete model file
                    if os.path.exists(f"{checkpoint_path}.pt"):
                        os.remove(f"{checkpoint_path}.pt")

                    # Delete metadata file
                    if os.path.exists(f"{checkpoint_path}_metadata.json"):
                        os.remove(f"{checkpoint_path}_metadata.json")

                    deleted_count += 1

                except Exception as e:
                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")

            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")

            return deleted_count

        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
            return 0
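
    # Illustrative effect: with max_checkpoints=10 and 13 saved checkpoints, the 3 ranked
    # lowest by self.metric_name are removed, deleting both the .pt file and its
    # *_metadata.json sidecar.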

    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get all checkpoints for a model

        Args:
            model_name: Name of the model

        Returns:
            List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files:
                return []

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]

                    # Check if model file exists
                    if not os.path.exists(f"{checkpoint_path}.pt"):
                        logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                        continue

                    checkpoints.append((checkpoint_path, metadata))

                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            # Sort by timestamp (newest first)
            checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)

            return checkpoints

        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []
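
# Illustrative usage of the manager class above (paths and names are examples only):
#
#   cm = CheckpointManager("models/checkpoints", max_checkpoints=10, metric_name="accuracy")
#   saved_path = cm.save_checkpoint("cnn", "runs/cnn_latest.pt", metrics={"accuracy": 0.83})
#   best_path, best_meta = cm.load_best_checkpoint("cnn")
#   for ckpt_path, meta in cm.get_all_checkpoints("cnn"):
#       print(ckpt_path, meta.get("metrics"))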

@@ -9,7 +9,7 @@ from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint

logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        metadata = save_checkpoint(
        metadata = self.checkpoint_manager.save_checkpoint(
            model=cnn_model,
            model_name=model_name,
            model_type='cnn',

@@ -137,7 +137,7 @@ class TrainingIntegration:
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

        metadata = save_checkpoint(
        metadata = self.checkpoint_manager.save_checkpoint(
            model=rl_agent,
            model_name=model_name,
            model_type='rl',

@@ -158,7 +158,7 @@ class TrainingIntegration:

    def load_best_model(self, model_name: str, model_class=None):
        try:
            result = load_best_checkpoint(model_name)
            result = self.checkpoint_manager.load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None