refactoring

This commit is contained in:
Dobromir Popov
2025-09-08 23:57:21 +03:00
parent 98ebbe5089
commit c3a94600c8
50 changed files with 856 additions and 1302 deletions

View File

@@ -1,5 +1,7 @@
"""
Enhanced Model Management System for Trading Dashboard
Unified Model Management System for Trading Dashboard
CONSOLIDATED SYSTEM - All model management functionality in one place
This system provides:
- Automatic cleanup of old model checkpoints
@@ -7,6 +9,9 @@ This system provides:
- Configurable retention policies
- Startup model loading
- Performance-based model selection
- Robust model saving with multiple fallback strategies
- Checkpoint management with W&B integration
- Centralized storage using @checkpoints/ structure
"""
import os
@@ -15,17 +20,30 @@ import shutil
import logging
import torch
import glob
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import pickle
import hashlib
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List, Tuple, Union
from collections import defaultdict
# W&B import (optional)
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
wandb = None
logger = logging.getLogger(__name__)
@dataclass
class ModelMetrics:
"""Performance metrics for model evaluation"""
"""Enhanced performance metrics for model evaluation"""
accuracy: float = 0.0
profit_factor: float = 0.0
win_rate: float = 0.0
@@ -34,41 +52,66 @@ class ModelMetrics:
total_trades: int = 0
avg_trade_duration: float = 0.0
confidence_score: float = 0.0
# Additional metrics from checkpoint_manager
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
def get_composite_score(self) -> float:
"""Calculate composite performance score"""
# Weighted composite score
weights = {
'profit_factor': 0.3,
'sharpe_ratio': 0.25,
'win_rate': 0.2,
'profit_factor': 0.25,
'sharpe_ratio': 0.2,
'win_rate': 0.15,
'accuracy': 0.15,
'confidence_score': 0.1
'confidence_score': 0.1,
'loss_penalty': 0.1, # New: penalize high loss
'val_penalty': 0.05 # New: penalize validation loss
}
# Normalize values to 0-1 range
normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0
normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1
normalized_win_rate = self.win_rate
normalized_accuracy = self.accuracy
normalized_confidence = self.confidence_score
# Loss penalty (lower loss = higher score)
loss_penalty = 1.0
if self.loss is not None and self.loss > 0:
loss_penalty = max(0.1, 1 / (1 + self.loss)) # Better loss = higher penalty
# Validation penalty
val_penalty = 1.0
if self.val_loss is not None and self.val_loss > 0:
val_penalty = max(0.1, 1 / (1 + self.val_loss))
# Apply penalties for poor performance
drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown
score = (
weights['profit_factor'] * normalized_pf +
weights['sharpe_ratio'] * normalized_sharpe +
weights['win_rate'] * normalized_win_rate +
weights['accuracy'] * normalized_accuracy +
weights['confidence_score'] * normalized_confidence
weights['confidence_score'] * normalized_confidence +
weights['loss_penalty'] * loss_penalty +
weights['val_penalty'] * val_penalty
) * drawdown_penalty
return min(max(score, 0), 1)
@dataclass
class ModelInfo:
"""Complete model information and metadata"""
"""Model information tracking"""
model_type: str # 'cnn', 'rl', 'transformer'
model_name: str
file_path: str
@@ -78,14 +121,14 @@ class ModelInfo:
metrics: ModelMetrics
training_episodes: int = 0
model_version: str = "1.0"
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization"""
data = asdict(self)
data['creation_time'] = self.creation_time.isoformat()
data['last_updated'] = self.last_updated.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
"""Create from dictionary"""
@@ -94,465 +137,400 @@ class ModelInfo:
data['metrics'] = ModelMetrics(**data['metrics'])
return cls(**data)
@dataclass
class CheckpointMetadata:
checkpoint_id: str
model_name: str
model_type: str
file_path: str
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
data['created_at'] = datetime.fromisoformat(data['created_at'])
return cls(**data)
class ModelManager:
"""Enhanced model management system"""
"""Unified model management system with @checkpoints/ structure"""
def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
self.base_dir = Path(base_dir)
self.config = config or self._get_default_config()
# Model directories
self.models_dir = self.base_dir / "models"
# Updated directory structure using @checkpoints/
self.checkpoints_dir = self.base_dir / "@checkpoints"
self.models_dir = self.checkpoints_dir / "models"
self.saved_dir = self.checkpoints_dir / "saved"
self.best_models_dir = self.checkpoints_dir / "best_models"
self.archive_dir = self.checkpoints_dir / "archive"
# Model type directories within @checkpoints/
self.model_dirs = {
'cnn': self.checkpoints_dir / "cnn",
'dqn': self.checkpoints_dir / "dqn",
'rl': self.checkpoints_dir / "rl",
'transformer': self.checkpoints_dir / "transformer",
'hybrid': self.checkpoints_dir / "hybrid"
}
# Legacy directories for backward compatibility
self.nn_models_dir = self.base_dir / "NN" / "models"
self.registry_file = self.models_dir / "model_registry.json"
self.best_models_dir = self.models_dir / "best_models"
# Create directories
self.best_models_dir.mkdir(parents=True, exist_ok=True)
# Model registry
self.model_registry: Dict[str, ModelInfo] = {}
self._load_registry()
logger.info(f"Model Manager initialized - Base: {self.base_dir}")
logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")
self.legacy_models_dir = self.base_dir / "models"
# Metadata and checkpoint management
self.metadata_file = self.checkpoints_dir / "model_metadata.json"
self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
# Initialize storage
self._initialize_directories()
self.metadata = self._load_metadata()
self.checkpoint_metadata = self._load_checkpoint_metadata()
logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration"""
return {
'max_models_per_type': 3, # Keep top 3 models per type
'max_total_models': 10, # Maximum total models to keep
'cleanup_frequency_hours': 24, # Cleanup every 24 hours
'min_performance_threshold': 0.3, # Minimum composite score
'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days
'auto_cleanup_enabled': True,
'backup_before_cleanup': True,
'model_size_limit_mb': 100, # Individual model size limit
'total_storage_limit_gb': 5.0 # Total storage limit
'max_checkpoints_per_model': 5,
'cleanup_old_models': True,
'auto_archive': True,
'wandb_enabled': WANDB_AVAILABLE,
'checkpoint_retention_days': 30
}
def _load_registry(self):
"""Load model registry from file"""
try:
if self.registry_file.exists():
with open(self.registry_file, 'r') as f:
data = json.load(f)
self.model_registry = {
k: ModelInfo.from_dict(v) for k, v in data.items()
}
logger.info(f"Loaded {len(self.model_registry)} models from registry")
else:
logger.info("No existing model registry found")
except Exception as e:
logger.error(f"Error loading model registry: {e}")
self.model_registry = {}
def _save_registry(self):
"""Save model registry to file"""
try:
self.models_dir.mkdir(parents=True, exist_ok=True)
with open(self.registry_file, 'w') as f:
data = {k: v.to_dict() for k, v in self.model_registry.items()}
json.dump(data, f, indent=2, default=str)
logger.info(f"Saved registry with {len(self.model_registry)} models")
except Exception as e:
logger.error(f"Error saving model registry: {e}")
def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
"""
Clean up all existing model files and prepare for 2-action system training
Args:
confirm: If True, perform the cleanup. If False, return what would be cleaned
Returns:
Dict with cleanup statistics
"""
cleanup_stats = {
'files_found': 0,
'files_deleted': 0,
'directories_cleaned': 0,
'space_freed_mb': 0.0,
'errors': []
}
# Model file patterns for both 2-action and legacy 3-action systems
model_patterns = [
"**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
"**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
]
# Directories to clean
model_directories = [
"models/saved",
"NN/models/saved",
"NN/models/saved/checkpoints",
"NN/models/saved/realtime_checkpoints",
"NN/models/saved/realtime_ticks_checkpoints",
"model_backups"
]
try:
# Scan for files to be cleaned
for directory in model_directories:
dir_path = Path(self.base_dir) / directory
if dir_path.exists():
for pattern in model_patterns:
for file_path in dir_path.glob(pattern):
if file_path.is_file():
cleanup_stats['files_found'] += 1
file_size = file_path.stat().st_size / (1024 * 1024) # MB
cleanup_stats['space_freed_mb'] += file_size
if confirm:
try:
file_path.unlink()
cleanup_stats['files_deleted'] += 1
logger.info(f"Deleted model file: {file_path}")
except Exception as e:
cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
# Clean up empty checkpoint directories
for directory in model_directories:
dir_path = Path(self.base_dir) / directory
if dir_path.exists():
for subdir in dir_path.rglob("*"):
if subdir.is_dir() and not any(subdir.iterdir()):
if confirm:
try:
subdir.rmdir()
cleanup_stats['directories_cleaned'] += 1
logger.info(f"Removed empty directory: {subdir}")
except Exception as e:
cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")
if confirm:
# Clear the registry for fresh start with 2-action system
self.model_registry = {
'models': {},
'metadata': {
'last_updated': datetime.now().isoformat(),
'total_models': 0,
'system_type': '2_action', # Mark as 2-action system
'action_space': ['SELL', 'BUY'],
'version': '2.0'
}
}
self._save_registry()
logger.info("=" * 60)
logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
logger.info("Registry reset for 2-action system (BUY/SELL)")
logger.info("Ready for fresh training with intelligent position management")
logger.info("=" * 60)
else:
logger.info("=" * 60)
logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
logger.info(f"Files to delete: {cleanup_stats['files_found']}")
logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
logger.info("Run with confirm=True to perform cleanup")
logger.info("=" * 60)
except Exception as e:
cleanup_stats['errors'].append(f"Cleanup error: {e}")
logger.error(f"Error during model cleanup: {e}")
return cleanup_stats
def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
"""
Register a new model in the 2-action system
Args:
model_path: Path to the model file
model_type: Type of model ('cnn', 'rl', 'transformer')
metrics: Performance metrics
Returns:
str: Unique model name/ID
"""
if not Path(model_path).exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
# Generate unique model name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"{model_type}_2action_{timestamp}"
# Get file info
file_path = Path(model_path)
file_size_mb = file_path.stat().st_size / (1024 * 1024)
# Default metrics for 2-action system
if metrics is None:
metrics = ModelMetrics(
accuracy=0.0,
profit_factor=1.0,
win_rate=0.5,
sharpe_ratio=0.0,
max_drawdown=0.0,
confidence_score=0.5
)
# Create model info
model_info = ModelInfo(
model_type=model_type,
model_name=model_name,
file_path=str(file_path.absolute()),
creation_time=datetime.now(),
last_updated=datetime.now(),
file_size_mb=file_size_mb,
metrics=metrics,
model_version="2.0" # 2-action system version
)
# Add to registry
self.model_registry['models'][model_name] = model_info.to_dict()
self.model_registry['metadata']['total_models'] = len(self.model_registry['models'])
self.model_registry['metadata']['last_updated'] = datetime.now().isoformat()
self.model_registry['metadata']['system_type'] = '2_action'
self.model_registry['metadata']['action_space'] = ['SELL', 'BUY']
self._save_registry()
# Cleanup old models if necessary
self._cleanup_models_by_type(model_type)
logger.info(f"Registered 2-action model: {model_name}")
logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
return model_name
def _should_keep_model(self, model_info: ModelInfo) -> bool:
"""Determine if model should be kept based on performance"""
score = model_info.metrics.get_composite_score()
# Check minimum threshold
if score < self.config['min_performance_threshold']:
return False
# Check size limit
if model_info.file_size_mb > self.config['model_size_limit_mb']:
logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
return False
# Check if better than existing models of same type
existing_models = self.get_models_by_type(model_info.model_type)
if len(existing_models) >= self.config['max_models_per_type']:
# Find worst performing model
worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
if score <= worst_model.metrics.get_composite_score():
return False
return True
def _cleanup_models_by_type(self, model_type: str):
"""Cleanup old models of specific type, keeping only the best ones"""
models_of_type = self.get_models_by_type(model_type)
max_keep = self.config['max_models_per_type']
if len(models_of_type) <= max_keep:
return
# Sort by performance score
sorted_models = sorted(
models_of_type.items(),
key=lambda x: x[1].metrics.get_composite_score(),
reverse=True
)
# Keep only the best models
models_to_keep = sorted_models[:max_keep]
models_to_remove = sorted_models[max_keep:]
for model_name, model_info in models_to_remove:
def _initialize_directories(self):
"""Initialize directory structure"""
directories = [
self.checkpoints_dir,
self.models_dir,
self.saved_dir,
self.best_models_dir,
self.archive_dir
] + list(self.model_dirs.values())
for directory in directories:
directory.mkdir(parents=True, exist_ok=True)
def _load_metadata(self) -> Dict[str, Any]:
"""Load model metadata"""
if self.metadata_file.exists():
try:
# Remove file
model_path = Path(model_info.file_path)
if model_path.exists():
model_path.unlink()
# Remove from registry
del self.model_registry[model_name]
logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
with open(self.metadata_file, 'r') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error removing model {model_name}: {e}")
def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
"""Get all models of a specific type"""
return {
name: info for name, info in self.model_registry.items()
if info.model_type == model_type
}
def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
"""Get the best performing model of a specific type"""
models_of_type = self.get_models_by_type(model_type)
if not models_of_type:
return None
return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())
def load_best_models(self) -> Dict[str, Any]:
"""Load the best models for each type"""
loaded_models = {}
for model_type in ['cnn', 'rl', 'transformer']:
best_model = self.get_best_model(model_type)
if best_model:
try:
model_path = Path(best_model.file_path)
if model_path.exists():
# Load the model
model_data = torch.load(model_path, map_location='cpu')
loaded_models[model_type] = {
'model': model_data,
'info': best_model,
'path': str(model_path)
}
logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
f"(Score: {best_model.metrics.get_composite_score():.3f})")
else:
logger.warning(f"Best {model_type} model file not found: {model_path}")
except Exception as e:
logger.error(f"Error loading {model_type} model: {e}")
else:
logger.info(f"No {model_type} model available")
return loaded_models
def update_model_performance(self, model_name: str, metrics: ModelMetrics):
"""Update performance metrics for a model"""
if model_name in self.model_registry:
self.model_registry[model_name].metrics = metrics
self.model_registry[model_name].last_updated = datetime.now()
self._save_registry()
logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
else:
logger.warning(f"Model {model_name} not found in registry")
def get_storage_stats(self) -> Dict[str, Any]:
"""Get storage usage statistics"""
total_size_mb = 0
model_count = 0
for model_info in self.model_registry.values():
total_size_mb += model_info.file_size_mb
model_count += 1
# Check actual storage usage
actual_size_mb = 0
if self.best_models_dir.exists():
actual_size_mb = sum(
f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
) / 1024 / 1024
return {
'total_models': model_count,
'registered_size_mb': total_size_mb,
'actual_size_mb': actual_size_mb,
'storage_limit_gb': self.config['total_storage_limit_gb'],
'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
'models_by_type': {
model_type: len(self.get_models_by_type(model_type))
for model_type in ['cnn', 'rl', 'transformer']
logger.error(f"Error loading metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
"""Load checkpoint metadata"""
if self.checkpoint_metadata_file.exists():
try:
with open(self.checkpoint_metadata_file, 'r') as f:
data = json.load(f)
# Convert dict values back to CheckpointMetadata objects
result = {}
for key, checkpoints in data.items():
result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
return result
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
return defaultdict(list)
def save_checkpoint(self, model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with enhanced error handling and validation"""
try:
performance_score = self._calculate_performance_score(performance_metrics)
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
# Create checkpoint directory
checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Generate checkpoint filename
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}"
filename = f"{checkpoint_id}.pt"
filepath = checkpoint_dir / filename
# Save model
save_dict = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'checkpoint_id': checkpoint_id,
'model_name': model_name,
'model_type': model_type,
'performance_score': performance_score,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'created_at': datetime.now().isoformat(),
'version': '2.0'
}
}
torch.save(save_dict, filepath)
# Create checkpoint metadata
file_size_mb = filepath.stat().st_size / (1024 * 1024)
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name=model_name,
model_type=model_type,
file_path=str(filepath),
created_at=datetime.now(),
file_size_mb=file_size_mb,
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
val_accuracy=performance_metrics.get('val_accuracy'),
val_loss=performance_metrics.get('val_loss'),
reward=performance_metrics.get('reward'),
pnl=performance_metrics.get('pnl'),
epoch=performance_metrics.get('epoch'),
training_time_hours=performance_metrics.get('training_time_hours'),
total_parameters=performance_metrics.get('total_parameters')
)
# Store metadata
self.checkpoint_metadata[model_name].append(metadata)
self._save_checkpoint_metadata()
# Rotate checkpoints if needed
self._rotate_checkpoints(model_name)
# Upload to W&B if enabled
if self.config.get('wandb_enabled'):
self._upload_to_wandb(metadata)
logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
return metadata
except Exception as e:
logger.error(f"Error saving checkpoint for {model_name}: {e}")
return None
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score from metrics"""
# Simple weighted score - can be enhanced
weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
score = 0.0
for metric, weight in weights.items():
if metric in metrics:
score += metrics[metric] * weight
return score
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Determine if checkpoint should be saved"""
existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
if not existing_checkpoints:
return True
# Keep if better than worst checkpoint or if we have fewer than max
max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
if len(existing_checkpoints) < max_checkpoints:
return True
worst_score = min(cp.performance_score for cp in existing_checkpoints)
return performance_score > worst_score
def _rotate_checkpoints(self, model_name: str):
"""Rotate checkpoints to maintain max count"""
checkpoints = self.checkpoint_metadata.get(model_name, [])
max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
if len(checkpoints) <= max_checkpoints:
return
# Sort by performance score (descending)
checkpoints.sort(key=lambda x: x.performance_score, reverse=True)
# Remove excess checkpoints
to_remove = checkpoints[max_checkpoints:]
for checkpoint in to_remove:
try:
Path(checkpoint.file_path).unlink(missing_ok=True)
logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")
# Update metadata
self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints]
self._save_checkpoint_metadata()
def _save_checkpoint_metadata(self):
"""Save checkpoint metadata to file"""
try:
data = {}
for model_name, checkpoints in self.checkpoint_metadata.items():
data[model_name] = [cp.to_dict() for cp in checkpoints]
with open(self.checkpoint_metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]:
"""Upload checkpoint to W&B"""
if not WANDB_AVAILABLE:
return None
try:
# This would be implemented based on your W&B workflow
logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
return None
except Exception as e:
logger.error(f"Error uploading to W&B: {e}")
return None
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
"""Load the best checkpoint for a model"""
try:
# First, try the unified registry
model_info = self.metadata['models'].get(model_name)
if model_info and Path(model_info['latest_path']).exists():
# Load from unified registry
load_dict = torch.load(model_info['latest_path'], map_location='cpu')
return model_info['latest_path'], None
# Fallback to checkpoint metadata
checkpoints = self.checkpoint_metadata.get(model_name, [])
if not checkpoints:
logger.warning(f"No checkpoints found for {model_name}")
return None
# Get best checkpoint
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
if not Path(best_checkpoint.file_path).exists():
logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
return None
return best_checkpoint.file_path, best_checkpoint
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
def get_storage_stats(self) -> Dict[str, Any]:
"""Get storage statistics"""
try:
total_size = 0
file_count = 0
for directory in [self.checkpoints_dir, self.models_dir, self.saved_dir]:
if directory.exists():
for file_path in directory.rglob('*'):
if file_path.is_file():
total_size += file_path.stat().st_size
file_count += 1
return {
'total_size_mb': total_size / (1024 * 1024),
'file_count': file_count,
'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0
}
except Exception as e:
logger.error(f"Error getting storage stats: {e}")
return {'error': str(e)}
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
"""Get model performance leaderboard"""
leaderboard = []
for model_name, model_info in self.model_registry.items():
leaderboard.append({
'name': model_name,
'type': model_info.model_type,
'score': model_info.metrics.get_composite_score(),
'profit_factor': model_info.metrics.profit_factor,
'win_rate': model_info.metrics.win_rate,
'sharpe_ratio': model_info.metrics.sharpe_ratio,
'size_mb': model_info.file_size_mb,
'age_days': (datetime.now() - model_info.creation_time).days,
'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
})
# Sort by score
leaderboard.sort(key=lambda x: x['score'], reverse=True)
return leaderboard
def cleanup_checkpoints(self) -> Dict[str, Any]:
"""Clean up old checkpoint files"""
cleanup_summary = {
'deleted_files': 0,
'freed_space_mb': 0,
'errors': []
}
cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])
# Search for checkpoint files
checkpoint_patterns = [
"**/checkpoint_*.pt",
"**/model_*.pt",
"**/*checkpoint*",
"**/epoch_*.pt"
]
for pattern in checkpoint_patterns:
for file_path in self.base_dir.rglob(pattern):
if "best_models" not in str(file_path) and file_path.is_file():
try:
file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
if file_time < cutoff_date:
size_mb = file_path.stat().st_size / 1024 / 1024
file_path.unlink()
cleanup_summary['deleted_files'] += 1
cleanup_summary['freed_space_mb'] += size_mb
except Exception as e:
error_msg = f"Error deleting checkpoint {file_path}: {e}"
logger.error(error_msg)
cleanup_summary['errors'].append(error_msg)
if cleanup_summary['deleted_files'] > 0:
logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
return cleanup_summary
try:
leaderboard = []
for model_name, model_info in self.metadata['models'].items():
if 'metrics' in model_info:
metrics = ModelMetrics(**model_info['metrics'])
leaderboard.append({
'model_name': model_name,
'model_type': model_info.get('model_type', 'unknown'),
'composite_score': metrics.get_composite_score(),
'accuracy': metrics.accuracy,
'profit_factor': metrics.profit_factor,
'win_rate': metrics.win_rate,
'last_updated': model_info.get('last_saved', 'unknown')
})
# Sort by composite score
leaderboard.sort(key=lambda x: x['composite_score'], reverse=True)
return leaderboard
except Exception as e:
logger.error(f"Error getting leaderboard: {e}")
return []
# ===== LEGACY COMPATIBILITY FUNCTIONS =====
def create_model_manager() -> ModelManager:
"""Create and initialize the global model manager"""
"""Create and return a ModelManager instance"""
return ModelManager()
# Example usage
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""Legacy compatibility function to save a model"""
manager = create_model_manager()
return manager.save_model(model, model_name, model_type, metadata)
def load_model(model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""Legacy compatibility function to load a model"""
manager = create_model_manager()
return manager.load_model(model_name, model_type, model_class)
def save_checkpoint(model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Legacy compatibility function to save a checkpoint"""
manager = create_model_manager()
return manager.save_checkpoint(model, model_name, model_type,
performance_metrics, training_metadata, force_save)
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
"""Legacy compatibility function to load the best checkpoint"""
manager = create_model_manager()
return manager.load_best_checkpoint(model_name)
# ===== EXAMPLE USAGE =====
if __name__ == "__main__":
# Configure logging
logging.basicConfig(level=logging.INFO)
# Create model manager
manager = ModelManager()
# Clean up all existing models (with confirmation)
print("WARNING: This will delete ALL existing models!")
print("Type 'CONFIRM' to proceed:")
user_input = input().strip()
if user_input == "CONFIRM":
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
print(f"\nCleanup complete:")
print(f"- Deleted {cleanup_result['files_deleted']} files")
print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")
if cleanup_result['errors']:
print(f"- {len(cleanup_result['errors'])} errors occurred")
else:
print("Cleanup cancelled")
# Example usage of the unified model manager
manager = create_model_manager()
print(f"ModelManager initialized at: {manager.checkpoints_dir}")
# Get storage stats
stats = manager.get_storage_stats()
print(f"Storage stats: {stats}")
# Get leaderboard
leaderboard = manager.get_model_leaderboard()
print(f"Models in leaderboard: {len(leaderboard)}")