Files
gogo2/NN/training/model_manager.py
2025-09-09 00:45:49 +03:00

602 lines
23 KiB
Python

"""
Unified Model Management System for Trading Dashboard
CONSOLIDATED SYSTEM - All model management functionality in one place
This system provides:
- Automatic cleanup of old model checkpoints
- Best model tracking with performance metrics
- Configurable retention policies
- Startup model loading
- Performance-based model selection
- Robust model saving with multiple fallback strategies
- Checkpoint management with W&B integration
- Centralized storage using @checkpoints/ structure
"""
import os
import json
import shutil
import logging
import torch
import glob
import pickle
import hashlib
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List, Tuple, Union
from collections import defaultdict
# W&B import (optional): start from "not available" and upgrade only if the
# package can actually be imported.
wandb = None
WANDB_AVAILABLE = False
try:
    import wandb  # rebinds the None placeholder above on success
    WANDB_AVAILABLE = True
except ImportError:
    pass  # W&B is optional; the placeholders above already describe its absence
logger = logging.getLogger(__name__)
@dataclass
class ModelMetrics:
    """Enhanced performance metrics for model evaluation.

    The first group of fields defaults to zero; the second group mirrors the
    optional metrics recorded for checkpoints and defaults to ``None`` when
    unknown.
    """
    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0
    # Additional metrics from checkpoint_manager
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` into the closed interval [0, 1]."""
        return min(max(value, 0.0), 1.0)

    def get_composite_score(self) -> float:
        """Calculate a weighted composite performance score.

        Every component is first normalised into [0, 1], then combined with
        weights that sum to 1.0. The weighted sum is scaled by a drawdown
        factor that falls linearly from 1.0 at 0% drawdown to 0.0 at 20%
        drawdown (and stays at 0.0 beyond that).

        Returns:
            Composite score in [0, 1].
        """
        # Weighted composite score (weights sum to 1.0)
        weights = {
            'profit_factor': 0.25,
            'sharpe_ratio': 0.2,
            'win_rate': 0.15,
            'accuracy': 0.15,
            'confidence_score': 0.1,
            'loss_penalty': 0.1,   # rewards low training loss
            'val_penalty': 0.05,   # rewards low validation loss
        }
        clamp = self._clamp01

        # Profit factor of 3+ maps to 1.0.
        normalized_pf = clamp(self.profit_factor / 3.0)
        # Sharpe ratio -2..2 maps linearly to 0..1.
        normalized_sharpe = clamp((self.sharpe_ratio + 2) / 4)
        # These are expected to already be fractions; clamp so an out-of-range
        # value (e.g. a percentage such as 65) cannot dominate the score.
        normalized_win_rate = clamp(self.win_rate)
        normalized_accuracy = clamp(self.accuracy)
        normalized_confidence = clamp(self.confidence_score)

        # Loss factors: 1.0 when the loss is unknown or non-positive, otherwise
        # 1 / (1 + loss) with a floor of 0.1 -- lower loss gives a higher score.
        loss_factor = 1.0
        if self.loss is not None and self.loss > 0:
            loss_factor = max(0.1, 1 / (1 + self.loss))
        val_factor = 1.0
        if self.val_loss is not None and self.val_loss > 0:
            val_factor = max(0.1, 1 / (1 + self.val_loss))

        # Linear drawdown discount reaching 0 at 20% drawdown. Capped at 1.0 so
        # a negative max_drawdown cannot inflate the score.
        drawdown_factor = clamp(1 - self.max_drawdown / 0.2)

        score = (
            weights['profit_factor'] * normalized_pf +
            weights['sharpe_ratio'] * normalized_sharpe +
            weights['win_rate'] * normalized_win_rate +
            weights['accuracy'] * normalized_accuracy +
            weights['confidence_score'] * normalized_confidence +
            weights['loss_penalty'] * loss_factor +
            weights['val_penalty'] * val_factor
        ) * drawdown_factor
        return clamp(score)
@dataclass
class ModelInfo:
    """Model information tracking record, convertible to and from a JSON-friendly dict."""
    model_type: str  # 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for JSON serialization.

        ``metrics`` becomes a nested dict (via ``asdict``) and both timestamps
        become ISO-8601 strings.
        """
        data = asdict(self)
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create an instance from a dictionary shaped like :meth:`to_dict` output.

        The input dictionary is not modified, so the same dict can be
        deserialized more than once. Timestamps may be ISO strings or
        ``datetime`` objects; ``metrics`` may be a dict or a ``ModelMetrics``.
        """
        data = dict(data)  # shallow copy: never mutate the caller's dict
        for key in ('creation_time', 'last_updated'):
            if not isinstance(data[key], datetime):
                data[key] = datetime.fromisoformat(data[key])
        if not isinstance(data['metrics'], ModelMetrics):
            data['metrics'] = ModelMetrics(**data['metrics'])
        return cls(**data)
@dataclass
class CheckpointMetadata:
    """Metadata describing one saved checkpoint file.

    Only ``created_at`` needs special handling for JSON: it is stored as an
    ISO-8601 string by :meth:`to_dict` and parsed back by :meth:`from_dict`.
    """
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict with ``created_at`` as an ISO-8601 string."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Build an instance from a dict shaped like :meth:`to_dict` output.

        The input dict is left unmodified; ``created_at`` may be either an
        ISO-8601 string or an existing ``datetime``.
        """
        data = dict(data)  # shallow copy: never mutate the caller's dict
        created_at = data['created_at']
        if not isinstance(created_at, datetime):
            data['created_at'] = datetime.fromisoformat(created_at)
        return cls(**data)
class ModelManager:
    """Unified model management system with @checkpoints/ structure.

    Everything lives under ``<base_dir>/@checkpoints``:

    * ``cnn/``, ``dqn/``, ``rl/``, ``transformer/``, ``hybrid/`` - per-type
      directories; checkpoints are written to their ``checkpoints/`` subfolder.
    * ``models/``, ``saved/``, ``best_models/``, ``archive/`` - shared folders
      (checkpoints of unknown model types go to ``saved/checkpoints/``).
    * ``model_metadata.json`` - model registry, a dict with a ``'models'`` key.
    * ``checkpoint_metadata.json`` - per-model lists of checkpoint metadata.
    """

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        """Create the directory layout and load persisted metadata.

        Args:
            base_dir: Root directory under which ``@checkpoints/`` is created.
            config: Optional overrides. They are merged on top of
                :meth:`_get_default_config`, so omitted keys keep their defaults.
        """
        self.base_dir = Path(base_dir)
        # Merge rather than replace so a partial config keeps the other defaults.
        self.config = {**self._get_default_config(), **(config or {})}

        # Updated directory structure using @checkpoints/
        self.checkpoints_dir = self.base_dir / "@checkpoints"
        self.models_dir = self.checkpoints_dir / "models"
        self.saved_dir = self.checkpoints_dir / "saved"
        self.best_models_dir = self.checkpoints_dir / "best_models"
        self.archive_dir = self.checkpoints_dir / "archive"

        # Model type directories within @checkpoints/
        self.model_dirs = {
            'cnn': self.checkpoints_dir / "cnn",
            'dqn': self.checkpoints_dir / "dqn",
            'rl': self.checkpoints_dir / "rl",
            'transformer': self.checkpoints_dir / "transformer",
            'hybrid': self.checkpoints_dir / "hybrid",
        }

        # Legacy directories kept for backward compatibility; this class does
        # not create or scan them.
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.legacy_models_dir = self.base_dir / "models"

        # Metadata and checkpoint management
        self.metadata_file = self.checkpoints_dir / "model_metadata.json"
        self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"

        # Initialize storage
        self._initialize_directories()
        self.metadata = self._load_metadata()
        self.checkpoint_metadata = self._load_checkpoint_metadata()
        logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the default configuration (W&B is enabled only if it imported)."""
        return {
            'max_checkpoints_per_model': 5,
            'cleanup_old_models': True,
            'auto_archive': True,
            'wandb_enabled': WANDB_AVAILABLE,
            'checkpoint_retention_days': 30,
        }

    def _initialize_directories(self):
        """Create the @checkpoints/ directory tree (idempotent)."""
        directories = [
            self.checkpoints_dir,
            self.models_dir,
            self.saved_dir,
            self.best_models_dir,
            self.archive_dir,
        ] + list(self.model_dirs.values())
        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load the model registry from ``model_metadata.json``.

        Returns:
            The stored dict, guaranteed to contain a ``'models'`` dict. A fresh
            empty registry is returned when the file is missing, unreadable or
            not a JSON object (errors are logged).
        """
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    # Callers index data['models'] directly.
                    data.setdefault('models', {})
                    return data
                logger.error(f"Error loading metadata: unexpected format in {self.metadata_file}")
            except Exception as e:
                logger.error(f"Error loading metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}

    def _load_checkpoint_metadata(self) -> Dict[str, List[CheckpointMetadata]]:
        """Load per-model checkpoint metadata from ``checkpoint_metadata.json``.

        Returns:
            A ``defaultdict(list)`` mapping model name to its checkpoints. It is
            a defaultdict even when loaded from disk, so new model names can be
            appended to directly. On any load error an empty one is returned.
        """
        result: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        if self.checkpoint_metadata_file.exists():
            try:
                with open(self.checkpoint_metadata_file, 'r') as f:
                    data = json.load(f)
                # Convert dict values back to CheckpointMetadata objects
                for key, checkpoints in data.items():
                    result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
            except Exception as e:
                logger.error(f"Error loading checkpoint metadata: {e}")
                return defaultdict(list)
        return result

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a model checkpoint and register it in the checkpoint metadata.

        Args:
            model: Object to checkpoint. Its ``state_dict()`` is stored when it
                has one, otherwise an empty dict is stored.
            model_name: Key under which checkpoints are grouped and rotated.
            model_type: Selects the directory from ``self.model_dirs``. Unknown
                types are stored under ``saved/checkpoints/``.
            performance_metrics: Used to compute the performance score. Known
                keys are also copied into the returned metadata.
            training_metadata: Extra data stored verbatim in the checkpoint file.
            force_save: Save even if the score would not make the retained set.

        Returns:
            The new :class:`CheckpointMetadata`, or ``None`` when the save was
            skipped or failed. Errors are logged, not raised.

        Note:
            Rotation runs after the save and keeps the highest scores. With
            ``force_save=True``, a low-scoring checkpoint may therefore be
            removed again immediately.
        """
        try:
            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Create checkpoint directory
            checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            now = datetime.now()
            checkpoint_id = self._unique_checkpoint_id(checkpoint_dir, model_name, now)
            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'checkpoint_id': checkpoint_id,
                'model_name': model_name,
                'model_type': model_type,
                'performance_score': performance_score,
                'performance_metrics': performance_metrics,
                'training_metadata': training_metadata or {},
                'created_at': now.isoformat(),
                'version': '2.0',
            }
            torch.save(save_dict, filepath)

            # Create checkpoint metadata
            file_size_mb = filepath.stat().st_size / (1024 * 1024)
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(filepath),
                created_at=now,
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=performance_metrics.get('epoch'),
                training_time_hours=performance_metrics.get('training_time_hours'),
                total_parameters=performance_metrics.get('total_parameters'),
            )

            # Store metadata. setdefault works whether or not the container is
            # a defaultdict, so a first save for a new model name cannot
            # KeyError and leave an unregistered file behind.
            self.checkpoint_metadata.setdefault(model_name, []).append(metadata)
            self._save_checkpoint_metadata()

            # Rotate checkpoints if needed
            self._rotate_checkpoints(model_name)

            # Upload to W&B if enabled
            if self.config.get('wandb_enabled'):
                self._upload_to_wandb(metadata)

            logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    @staticmethod
    def _unique_checkpoint_id(checkpoint_dir: Path, model_name: str, when: datetime) -> str:
        """Return ``<model_name>_<YYYYmmdd_HHMMSS>``, adding ``_<n>`` if that file exists.

        Timestamps have one-second resolution. Without the suffix, a second save
        within the same second would overwrite the first file, and both
        metadata entries would point at it.
        """
        base_id = f"{model_name}_{when.strftime('%Y%m%d_%H%M%S')}"
        checkpoint_id = base_id
        suffix = 1
        while (checkpoint_dir / f"{checkpoint_id}.pt").exists():
            checkpoint_id = f"{base_id}_{suffix}"
            suffix += 1
        return checkpoint_id

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Return a weighted sum of accuracy, profit_factor, win_rate and sharpe_ratio.

        Missing or ``None`` metrics contribute nothing; no normalisation is
        applied to the raw values.
        """
        weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
        score = 0.0
        for metric, weight in weights.items():
            value = metrics.get(metric)
            if value is not None:
                score += value * weight
        return score

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Return True if a checkpoint with this score would be retained.

        That is the case when the model has fewer than
        ``max_checkpoints_per_model`` checkpoints, or the score beats the
        current worst one.
        """
        existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
        if not existing_checkpoints:
            return True
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if len(existing_checkpoints) < max_checkpoints:
            return True
        worst_score = min(cp.performance_score for cp in existing_checkpoints)
        return performance_score > worst_score

    def _rotate_checkpoints(self, model_name: str):
        """Keep only the top ``max_checkpoints_per_model`` checkpoints by score.

        Files of the dropped checkpoints are deleted. Deletion errors are
        logged, and the metadata entry is dropped regardless.
        """
        checkpoints = self.checkpoint_metadata.get(model_name, [])
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if len(checkpoints) <= max_checkpoints:
            return
        # Sort by performance score (descending)
        checkpoints.sort(key=lambda x: x.performance_score, reverse=True)
        # Remove excess checkpoints
        for checkpoint in checkpoints[max_checkpoints:]:
            try:
                Path(checkpoint.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")
        # Update metadata
        self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints]
        self._save_checkpoint_metadata()

    def _save_checkpoint_metadata(self):
        """Persist checkpoint metadata to ``checkpoint_metadata.json``.

        The JSON is written to a temporary sibling file, which then atomically
        replaces the target. A crash mid-write cannot leave a truncated file,
        which ``_load_checkpoint_metadata`` would otherwise discard entirely.
        Errors are logged, not raised.
        """
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoints]
                for model_name, checkpoints in self.checkpoint_metadata.items()
            }
            tmp_path = self.checkpoint_metadata_file.with_name(
                self.checkpoint_metadata_file.name + '.tmp')
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.checkpoint_metadata_file)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]:
        """Upload a checkpoint to W&B (placeholder: currently only logs and returns None)."""
        if not WANDB_AVAILABLE:
            return None
        try:
            # This would be implemented based on your W&B workflow
            logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
            return None
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, Optional[CheckpointMetadata]]]:
        """Locate the best checkpoint file for ``model_name``.

        Lookup order:

        1. The model registry entry ``self.metadata['models'][model_name]['latest_path']``.
           If that file exists, its path is returned with ``None`` metadata,
           because registry entries carry no :class:`CheckpointMetadata`.
        2. Otherwise, the checkpoint with the highest ``performance_score``,
           provided its file still exists.

        Nothing is deserialized here; callers receive a path and load it.

        Returns:
            ``(file_path, metadata_or_None)``, or ``None`` if nothing usable was
            found or an error occurred (errors are logged).
        """
        try:
            # First, try the unified registry
            model_info = self.metadata.get('models', {}).get(model_name)
            latest_path = model_info.get('latest_path') if isinstance(model_info, dict) else None
            if latest_path and Path(latest_path).exists():
                return latest_path, None

            # Fallback to checkpoint metadata
            checkpoints = self.checkpoint_metadata.get(model_name, [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for {model_name}")
                return None

            # Get best checkpoint
            best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
            if not Path(best_checkpoint.file_path).exists():
                logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
                return None
            return best_checkpoint.file_path, best_checkpoint
        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def get_storage_stats(self) -> Dict[str, Any]:
        """Return total size/count of files under the managed directories.

        Returns:
            ``{'total_size_mb', 'file_count', 'directories'}``, where
            ``directories`` is the number of direct entries in
            ``@checkpoints/``. On error, returns ``{'error': message}``.
        """
        try:
            total_size = 0
            file_count = 0
            # models_dir and saved_dir are nested inside checkpoints_dir, so
            # scanning all three naively would count their files twice; track
            # resolved paths so every file is counted once.
            seen_files = set()
            for directory in (self.checkpoints_dir, self.models_dir, self.saved_dir):
                if not directory.exists():
                    continue
                for file_path in directory.rglob('*'):
                    if not file_path.is_file():
                        continue
                    resolved = file_path.resolve()
                    if resolved in seen_files:
                        continue
                    seen_files.add(resolved)
                    total_size += file_path.stat().st_size
                    file_count += 1
            return {
                'total_size_mb': total_size / (1024 * 1024),
                'file_count': file_count,
                'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0,
            }
        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints (compatible with old checkpoint_manager interface).

        Every ``*.pt`` file under a per-type directory counts as a checkpoint
        of the "model" named after that directory. ``*.pt`` files under
        ``saved/`` are added to the totals only. ``best_performance`` is not
        tracked and is always 0.0; ``best_checkpoint_id`` is the most recently
        modified file.
        """
        try:
            stats = {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
            }
            # Count files in the per-type directories as "checkpoints"
            for checkpoint_dir in self.model_dirs.values():
                if not checkpoint_dir.exists():
                    continue
                file_stats = [(f, f.stat()) for f in checkpoint_dir.rglob('*.pt')]
                if not file_stats:
                    continue
                model_size = sum(st.st_size for _, st in file_stats)
                stats['total_models'] += 1
                stats['total_checkpoints'] += len(file_stats)
                stats['total_size_mb'] += model_size / (1024 * 1024)
                # Get the most recent file as "latest"
                latest_file = max(file_stats, key=lambda item: item[1].st_mtime)[0]
                stats['models'][checkpoint_dir.name] = {
                    'checkpoint_count': len(file_stats),
                    'total_size_mb': model_size / (1024 * 1024),
                    'best_performance': 0.0,  # Not tracked in unified system
                    'best_checkpoint_id': latest_file.name,
                    'latest_checkpoint': latest_file.name,
                }

            # Also check saved models directory
            if self.saved_dir.exists():
                saved_files = list(self.saved_dir.rglob('*.pt'))
                if saved_files:
                    stats['total_checkpoints'] += len(saved_files)
                    saved_size = sum(f.stat().st_size for f in saved_files)
                    stats['total_size_mb'] += saved_size / (1024 * 1024)
            return stats
        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
                'error': str(e),
            }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Return registry models that have metrics, sorted by composite score (best first).

        Entries whose stored metrics cannot be turned into ``ModelMetrics`` are
        skipped with a warning, rather than hiding every other model.
        """
        try:
            leaderboard = []
            for model_name, model_info in self.metadata.get('models', {}).items():
                if 'metrics' not in model_info:
                    continue
                try:
                    metrics = ModelMetrics(**model_info['metrics'])
                except TypeError as e:
                    logger.warning(f"Skipping {model_name} in leaderboard: invalid metrics ({e})")
                    continue
                leaderboard.append({
                    'model_name': model_name,
                    'model_type': model_info.get('model_type', 'unknown'),
                    'composite_score': metrics.get_composite_score(),
                    'accuracy': metrics.accuracy,
                    'profit_factor': metrics.profit_factor,
                    'win_rate': metrics.win_rate,
                    'last_updated': model_info.get('last_saved', 'unknown'),
                })
            # Sort by composite score
            leaderboard.sort(key=lambda x: x['composite_score'], reverse=True)
            return leaderboard
        except Exception as e:
            logger.error(f"Error getting leaderboard: {e}")
            return []
# ===== LEGACY COMPATIBILITY FUNCTIONS =====
def create_model_manager() -> ModelManager:
    """Build a ModelManager with its default base directory and configuration.

    Returns:
        A newly constructed ModelManager; nothing is cached between calls.
    """
    manager = ModelManager()
    return manager
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Legacy compatibility function to save a model.

    ``ModelManager`` defines no ``save_model`` method, so the previous
    delegation always raised ``AttributeError``. This function instead writes
    the model to ``<type dir>/<model_name>.pt``, using ``saved/`` for unknown
    types. It then records the file in the manager's model registry under
    ``latest_path``, ``model_type`` and ``last_saved``, which are the keys
    ``ModelManager.load_best_checkpoint`` and ``get_model_leaderboard`` read,
    and persists the registry to ``metadata_file``.

    Args:
        model: Object to save; its ``state_dict()`` is stored when available.
        model_name: Registry key and file stem.
        model_type: Selects the target directory from ``manager.model_dirs``.
        metadata: Extra data stored verbatim in the saved file.

    Returns:
        True on success, False on any error (errors are logged).
    """
    try:
        manager = create_model_manager()
        target_dir = manager.model_dirs.get(model_type, manager.saved_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        file_path = target_dir / f"{model_name}.pt"
        saved_at = datetime.now().isoformat()

        torch.save({
            'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
            'model_class': model.__class__.__name__,
            'model_name': model_name,
            'model_type': model_type,
            'metadata': metadata or {},
            'created_at': saved_at,
        }, file_path)

        # Update (not replace) the registry entry so existing keys such as
        # 'metrics' survive a re-save.
        entry = manager.metadata.setdefault('models', {}).setdefault(model_name, {})
        entry.update({
            'latest_path': str(file_path),
            'model_type': model_type,
            'last_saved': saved_at,
        })
        manager.metadata['last_updated'] = saved_at
        with open(manager.metadata_file, 'w') as f:
            json.dump(manager.metadata, f, indent=2)
        return True
    except Exception as e:
        logger.error(f"Error saving model {model_name}: {e}")
        return False
def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Legacy compatibility function to load a model.

    ``ModelManager`` defines no ``load_model`` method, so the previous
    delegation always raised ``AttributeError``. This function instead
    resolves the file with ``ModelManager.load_best_checkpoint``, which checks
    the registry's ``latest_path`` first and then the best-scoring checkpoint.
    It then deserializes that file.

    Args:
        model_name: Name to look up.
        model_type: Accepted for backward compatibility; lookup is by name only.
        model_class: If ``None``, the raw loaded checkpoint is returned. If it
            is a class, it is instantiated with no arguments. Otherwise it is
            treated as an existing model instance. If the checkpoint contains a
            non-empty ``'model_state_dict'`` and the model has
            ``load_state_dict``, that state is loaded into it.

    Returns:
        The loaded checkpoint or model, or ``None`` if nothing was found or
        loading failed (errors are logged).
    """
    manager = create_model_manager()
    found = manager.load_best_checkpoint(model_name)
    if found is None:
        return None
    file_path, _ = found
    try:
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoint files from trusted sources.
        checkpoint = torch.load(file_path, map_location='cpu')
        if model_class is None:
            return checkpoint
        model = model_class() if isinstance(model_class, type) else model_class
        state_dict = checkpoint.get('model_state_dict') if isinstance(checkpoint, dict) else None
        if state_dict and hasattr(model, 'load_state_dict'):
            model.load_state_dict(state_dict)
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name} from {file_path}: {e}")
        return None
def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Module-level shim over ``ModelManager.save_checkpoint``.

    A fresh manager is built for every call, and all arguments are forwarded
    positionally, unchanged.

    Returns:
        Whatever the manager returns: the new checkpoint metadata, or None.
    """
    return create_model_manager().save_checkpoint(
        model,
        model_name,
        model_type,
        performance_metrics,
        training_metadata,
        force_save,
    )
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Module-level shim over ``ModelManager.load_best_checkpoint``.

    Builds a fresh manager and returns its ``(path, metadata)`` result for
    ``model_name``, or None.
    """
    return create_model_manager().load_best_checkpoint(model_name)
# ===== EXAMPLE USAGE =====
if __name__ == "__main__":
    # Demo: construct the manager, then report storage usage and leaderboard size.
    demo_manager = create_model_manager()
    print(f"ModelManager initialized at: {demo_manager.checkpoints_dir}")

    storage_stats = demo_manager.get_storage_stats()
    print(f"Storage stats: {storage_stats}")

    ranked_models = demo_manager.get_model_leaderboard()
    print(f"Models in leaderboard: {len(ranked_models)}")