"""
Enhanced Model Management System for Trading Dashboard

This system provides:
- Automatic cleanup of old model checkpoints
- Best-model tracking with performance metrics
- Configurable retention policies
- Loading of the best saved models at startup
- Performance-based model selection
"""

import glob
import json
import logging
import math
import os
import shutil
from dataclasses import asdict, dataclass, fields
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

# Module-level logger used by every class and function in this file.
logger = logging.getLogger(name=__name__)

@dataclass
class ModelMetrics:
    """Performance metrics for model evaluation.

    ``win_rate``, ``accuracy`` and ``confidence_score`` are treated as fractions
    in [0, 1]; ``max_drawdown`` is a fraction of equity (0.2 == 20%).
    """

    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` into [0, 1], mapping NaN to 0.0.

        NaN would otherwise survive ``min``/``max`` and make the composite
        score unorderable, breaking ranking and best-model selection.
        """
        if math.isnan(value):
            return 0.0
        return min(max(value, 0.0), 1.0)

    def get_composite_score(self) -> float:
        """Calculate a composite performance score in [0, 1].

        The score is a weighted sum of normalized metrics: profit factor 30%,
        Sharpe ratio 25%, win rate 20%, accuracy 15% and confidence 10%. That
        sum is multiplied by a drawdown factor which falls linearly from 1 at
        0% drawdown to 0 at a drawdown of 20% or more. NaN inputs count as the
        worst possible value for their component.
        """
        weights = {
            'profit_factor': 0.3,
            'sharpe_ratio': 0.25,
            'win_rate': 0.2,
            'accuracy': 0.15,
            'confidence_score': 0.1,
        }

        # Normalize values to the 0-1 range
        normalized_pf = self._clamp01(self.profit_factor / 3.0)  # PF of 3+ = 1.0
        normalized_sharpe = self._clamp01((self.sharpe_ratio + 2) / 4)  # Sharpe -2..2 -> 0..1
        normalized_win_rate = self._clamp01(self.win_rate)
        normalized_accuracy = self._clamp01(self.accuracy)
        normalized_confidence = self._clamp01(self.confidence_score)

        # Drawdown is judged by magnitude, so 0.15 and -0.15 both mean a 15%
        # drawdown. The factor is clamped so it can only reduce the score:
        # previously a negative drawdown produced a multiplier > 1.
        drawdown_penalty = self._clamp01(1 - abs(self.max_drawdown) / 0.2)

        score = (
            weights['profit_factor'] * normalized_pf +
            weights['sharpe_ratio'] * normalized_sharpe +
            weights['win_rate'] * normalized_win_rate +
            weights['accuracy'] * normalized_accuracy +
            weights['confidence_score'] * normalized_confidence
        ) * drawdown_penalty

        return self._clamp01(score)

@dataclass
class ModelInfo:
    """Complete model information and metadata."""

    model_type: str  # e.g. 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dict (datetimes as ISO-8601 strings)."""
        data = asdict(self)
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create an instance from a dict produced by :meth:`to_dict`.

        The input dict is not modified. Keys that are not fields of this
        class or of ``ModelMetrics`` are ignored, so registries written by a
        newer version with extra fields still load.

        Raises:
            KeyError: a required field (e.g. ``creation_time``) is missing.
            TypeError / ValueError: a field has an unusable value.
        """
        known_fields = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in data.items() if k in known_fields}
        kwargs['creation_time'] = datetime.fromisoformat(data['creation_time'])
        kwargs['last_updated'] = datetime.fromisoformat(data['last_updated'])

        metric_fields = {f.name for f in fields(ModelMetrics)}
        kwargs['metrics'] = ModelMetrics(
            **{k: v for k, v in data['metrics'].items() if k in metric_fields}
        )
        return cls(**kwargs)

class ModelManager:
    """Enhanced model management system.

    Keeps an in-memory registry (``model_name -> ModelInfo``) persisted as
    JSON at ``<base_dir>/models/model_registry.json``. The manager prunes each
    model type down to its best-scoring entries, loads the best model per
    type, and offers bulk deletion of legacy model and checkpoint files.
    """

    # Model types iterated by ``load_best_models`` and ``get_storage_stats``.
    MODEL_TYPES = ('cnn', 'rl', 'transformer')

    # Directory names that ``cleanup_checkpoints`` never descends into. These
    # are the curated best-model store, VCS metadata, and installed packages
    # or virtualenvs. The last group holds library files such as
    # ``torch/utils/checkpoint.py`` that the broad ``*checkpoint*`` pattern
    # would otherwise match and delete.
    _PROTECTED_DIR_NAMES = frozenset({
        'best_models', '.git', '.hg', '.svn', '.venv', 'venv',
        'site-packages', 'dist-packages', 'node_modules', '__pycache__',
    })

    # Suffixes never treated as checkpoints even when the file name matches a
    # checkpoint pattern: source code, notebooks, docs and config files.
    _NON_MODEL_SUFFIXES = frozenset({
        '.py', '.pyc', '.pyi', '.pyx', '.ipynb', '.md', '.rst', '.txt',
        '.json', '.yaml', '.yml', '.toml', '.cfg', '.ini',
        '.sh', '.bat', '.ps1', '.js', '.ts', '.html', '.css',
    })

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        """Set up directories and load any persisted registry.

        Creates ``<base_dir>/models/best_models`` if it does not exist.

        Args:
            base_dir: Project root; all model directories are resolved beneath it.
            config: Optional overrides layered on top of ``_get_default_config()``;
                keys that are not given fall back to the defaults.
        """
        self.base_dir = Path(base_dir)
        # Merge instead of replace so a partial config cannot cause KeyErrors later.
        self.config = {**self._get_default_config(), **(config or {})}

        # Model directories
        self.models_dir = self.base_dir / "models"
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.registry_file = self.models_dir / "model_registry.json"
        self.best_models_dir = self.models_dir / "best_models"

        # Create directories
        self.best_models_dir.mkdir(parents=True, exist_ok=True)

        # Model registry: always model_name -> ModelInfo. Registry-level
        # information (system type, action space, ...) lives separately.
        self.model_registry: Dict[str, ModelInfo] = {}
        self.registry_metadata: Dict[str, Any] = {}
        self._load_registry()

        logger.info(f"Model Manager initialized - Base: {self.base_dir}")
        logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get the default configuration.

        ``max_total_models``, ``cleanup_frequency_hours``, ``auto_cleanup_enabled``
        and ``backup_before_cleanup`` are not read by ModelManager itself.
        """
        return {
            'max_models_per_type': 3,  # Keep top 3 models per type
            'max_total_models': 10,  # Maximum total models to keep
            'cleanup_frequency_hours': 24,  # Cleanup every 24 hours
            'min_performance_threshold': 0.3,  # Minimum composite score
            'max_checkpoint_age_days': 7,  # Delete checkpoints older than 7 days
            'auto_cleanup_enabled': True,
            'backup_before_cleanup': True,
            'model_size_limit_mb': 100,  # Individual model size limit
            'total_storage_limit_gb': 5.0,  # Total storage limit
        }

    def _load_registry(self):
        """Load the model registry and its metadata from ``registry_file``.

        Two on-disk layouts are accepted: the current
        ``{"metadata": {...}, "models": {...}}`` and the legacy flat
        ``{model_name: model_info_dict}``. A malformed entry is skipped with an
        error log rather than discarding the whole registry. Errors are logged,
        not raised.
        """
        self.model_registry = {}
        self.registry_metadata = {}

        if not self.registry_file.exists():
            logger.info("No existing model registry found")
            return

        try:
            with open(self.registry_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logger.error(f"Error loading model registry: {e}")
            return

        if not isinstance(data, dict):
            logger.error(f"Error loading model registry: expected a JSON object, got {type(data).__name__}")
            return

        if isinstance(data.get('models'), dict) and isinstance(data.get('metadata'), dict):
            entries = data['models']
            self.registry_metadata = dict(data['metadata'])
        else:
            entries = data  # legacy flat layout

        for name, entry in entries.items():
            try:
                self.model_registry[name] = ModelInfo.from_dict(entry)
            except (AttributeError, KeyError, TypeError, ValueError) as e:
                logger.error(f"Skipping malformed registry entry {name!r}: {e}")

        logger.info(f"Loaded {len(self.model_registry)} models from registry")

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback that unwraps NumPy/torch numbers and stringifies the rest.

        A plain ``default=str`` turned e.g. ``np.float32`` metrics into strings.
        Those strings crashed ``get_composite_score`` after the registry was
        reloaded.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        return str(obj)

    def _save_registry(self):
        """Persist the registry as ``{"metadata": ..., "models": ...}``.

        The payload is built before any file is touched. It is written to a
        sibling temp file, which then atomically replaces the registry. A
        serialization error or crash therefore cannot leave a truncated
        registry behind. Errors are logged, not raised.
        """
        tmp_file = self.registry_file.with_name(self.registry_file.name + '.tmp')
        try:
            payload = {
                'metadata': self.registry_metadata,
                'models': {k: v.to_dict() for k, v in self.model_registry.items()},
            }
            self.models_dir.mkdir(parents=True, exist_ok=True)
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2, default=self._json_default)
            os.replace(tmp_file, self.registry_file)
            logger.info(f"Saved registry with {len(self.model_registry)} models")
        except Exception as e:
            logger.error(f"Error saving model registry: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass

    @staticmethod
    def _collect_unique_files(roots: List[Path], patterns: List[str]) -> List[Path]:
        """Return each file under ``roots`` matching any glob pattern exactly once.

        The patterns overlap (``**/*.pt`` vs ``**/model_*``) and so do the roots
        (some are nested in others). Matches are therefore de-duplicated by
        resolved path and returned in sorted order.
        """
        unique: Dict[Path, Path] = {}
        for root in roots:
            for pattern in patterns:
                for path in root.glob(pattern):
                    if path.is_file():
                        unique.setdefault(path.resolve(), path)
        return [unique[key] for key in sorted(unique)]

    @staticmethod
    def _remove_empty_subdirs(roots: List[Path], errors: List[str]) -> int:
        """Remove empty directories below ``roots``; the roots themselves are kept.

        Directories are processed deepest-first, so a parent emptied by
        removing its children is removed too. Failures are appended to
        ``errors``.

        Returns:
            Number of directories removed.
        """
        subdirs: Dict[Path, Path] = {}
        for root in roots:
            if root.exists():
                for path in root.rglob('*'):
                    if path.is_dir():
                        subdirs.setdefault(path.resolve(), path)

        removed = 0
        for key in sorted(subdirs, key=lambda p: len(p.parts), reverse=True):
            subdir = subdirs[key]
            try:
                if any(subdir.iterdir()):
                    continue
                subdir.rmdir()
            except OSError as e:
                errors.append(f"Failed to remove directory {subdir}: {e}")
                continue
            removed += 1
            logger.info(f"Removed empty directory: {subdir}")
        return removed

    def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
        """
        Clean up all existing model files and prepare for 2-action system training.

        Args:
            confirm: If True, delete the files, remove emptied subdirectories and
                reset the registry. If False, only report what would be cleaned.

        Returns:
            Dict with cleanup statistics:
            - ``files_found``: number of unique matching files.
            - ``files_deleted``: number of files actually deleted.
            - ``directories_cleaned``: number of empty directories removed.
            - ``space_freed_mb``: MB actually deleted, or in preview mode the
              MB that would be deleted.
            - ``errors``: list of error messages.
        """
        cleanup_stats = {
            'files_found': 0,
            'files_deleted': 0,
            'directories_cleaned': 0,
            'space_freed_mb': 0.0,
            'errors': []
        }

        # Model file patterns for both 2-action and legacy 3-action systems
        model_patterns = [
            "**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
            "**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
        ]

        # Directories to clean (relative to base_dir; some are nested in others)
        model_directories = [
            "models/saved",
            "NN/models/saved",
            "NN/models/saved/checkpoints",
            "NN/models/saved/realtime_checkpoints",
            "NN/models/saved/realtime_ticks_checkpoints",
            "model_backups"
        ]

        try:
            roots = [self.base_dir / d for d in model_directories if (self.base_dir / d).exists()]
            model_files = self._collect_unique_files(roots, model_patterns)
            cleanup_stats['files_found'] = len(model_files)

            for file_path in model_files:
                try:
                    file_size_mb = file_path.stat().st_size / (1024 * 1024)
                except OSError as e:
                    cleanup_stats['errors'].append(f"Failed to inspect {file_path}: {e}")
                    continue

                if confirm:
                    try:
                        file_path.unlink()
                    except OSError as e:
                        cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
                        continue
                    cleanup_stats['files_deleted'] += 1
                    logger.info(f"Deleted model file: {file_path}")

                # Only count space for files actually deleted (or, in preview, deletable).
                cleanup_stats['space_freed_mb'] += file_size_mb

            if confirm:
                cleanup_stats['directories_cleaned'] = self._remove_empty_subdirs(
                    roots, cleanup_stats['errors']
                )

                # Clear the registry for a fresh start with the 2-action system.
                # model_registry keeps its Dict[str, ModelInfo] shape; the
                # system description goes into the separate metadata dict.
                self.model_registry = {}
                self.registry_metadata = {
                    'last_updated': datetime.now().isoformat(),
                    'total_models': 0,
                    'system_type': '2_action',  # Mark as 2-action system
                    'action_space': ['SELL', 'BUY'],
                    'version': '2.0'
                }
                self._save_registry()

                logger.info("=" * 60)
                logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
                logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
                logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
                logger.info("Registry reset for 2-action system (BUY/SELL)")
                logger.info("Ready for fresh training with intelligent position management")
                logger.info("=" * 60)
            else:
                logger.info("=" * 60)
                logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
                logger.info(f"Files to delete: {cleanup_stats['files_found']}")
                logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
                logger.info("Run with confirm=True to perform cleanup")
                logger.info("=" * 60)

        except Exception as e:
            cleanup_stats['errors'].append(f"Cleanup error: {e}")
            logger.error(f"Error during model cleanup: {e}")

        return cleanup_stats

    def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
        """
        Register a new model in the 2-action system.

        After registration the model type is pruned to ``max_models_per_type``.
        That pruning may delete the file of the newly registered model if it
        scores worst.

        Args:
            model_path: Path to the model file
            model_type: Type of model ('cnn', 'rl', 'transformer')
            metrics: Performance metrics (neutral defaults when omitted)

        Returns:
            str: Unique model name/ID

        Raises:
            FileNotFoundError: ``model_path`` does not exist.
        """
        file_path = Path(model_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Generate a unique model name. The timestamp has one-second
        # resolution, so add a numeric suffix instead of overwriting an entry
        # registered within the same second.
        now = datetime.now()
        base_name = f"{model_type}_2action_{now.strftime('%Y%m%d_%H%M%S')}"
        model_name = base_name
        suffix = 1
        while model_name in self.model_registry:
            model_name = f"{base_name}_{suffix}"
            suffix += 1

        file_size_mb = file_path.stat().st_size / (1024 * 1024)

        # Default metrics for 2-action system
        if metrics is None:
            metrics = ModelMetrics(
                accuracy=0.0,
                profit_factor=1.0,
                win_rate=0.5,
                sharpe_ratio=0.0,
                max_drawdown=0.0,
                confidence_score=0.5
            )

        model_info = ModelInfo(
            model_type=model_type,
            model_name=model_name,
            file_path=str(file_path.absolute()),
            creation_time=now,
            last_updated=now,
            file_size_mb=file_size_mb,
            metrics=metrics,
            model_version="2.0"  # 2-action system version
        )

        # Store the ModelInfo itself: every other method expects ModelInfo values.
        self.model_registry[model_name] = model_info

        # Cleanup old models if necessary (persists its own removals)
        self._cleanup_models_by_type(model_type)

        self.registry_metadata.update({
            'total_models': len(self.model_registry),
            'last_updated': now.isoformat(),
            'system_type': '2_action',
            'action_space': ['SELL', 'BUY'],
        })
        self._save_registry()

        logger.info(f"Registered 2-action model: {model_name}")
        logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
        logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
        if model_name not in self.model_registry:
            logger.info(f"Model {model_name} was pruned immediately by the retention policy")

        return model_name

    def _should_keep_model(self, model_info: ModelInfo) -> bool:
        """Determine whether a model should be kept based on performance.

        A model is rejected if it scores below ``min_performance_threshold`` or
        exceeds ``model_size_limit_mb``. It is also rejected when its type is
        already at ``max_models_per_type`` and it does not beat the worst of
        those models.
        """
        score = model_info.metrics.get_composite_score()

        # Check minimum threshold
        if score < self.config['min_performance_threshold']:
            return False

        # Check size limit
        if model_info.file_size_mb > self.config['model_size_limit_mb']:
            logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
            return False

        # Check if better than existing models of same type
        existing_models = self.get_models_by_type(model_info.model_type)
        if len(existing_models) >= self.config['max_models_per_type']:
            # Find worst performing model
            worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
            if score <= worst_model.metrics.get_composite_score():
                return False

        return True

    def _cleanup_models_by_type(self, model_type: str):
        """Keep only the ``max_models_per_type`` best models of ``model_type``.

        Lower-scoring entries are removed from the registry. Their files are
        deleted unless another registry entry still points at the same file.
        The registry is saved if anything was removed.
        """
        models_of_type = self.get_models_by_type(model_type)
        max_keep = self.config['max_models_per_type']

        if len(models_of_type) <= max_keep:
            return

        # Sort by performance score, best first
        ranked = sorted(
            models_of_type.items(),
            key=lambda item: item[1].metrics.get_composite_score(),
            reverse=True
        )

        removed_any = False
        for model_name, model_info in ranked[max_keep:]:
            try:
                model_path = Path(model_info.file_path)
                # The same checkpoint file is often re-registered. Deleting it
                # here would break the kept entries that reference it.
                still_referenced = any(
                    other_name != model_name and Path(other.file_path) == model_path
                    for other_name, other in self.model_registry.items()
                )
                if not still_referenced and model_path.exists():
                    model_path.unlink()

                del self.model_registry[model_name]
                removed_any = True

                logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")

            except (OSError, KeyError) as e:
                logger.error(f"Error removing model {model_name}: {e}")

        if removed_any:
            self._save_registry()

    def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
        """Get all registered models of a specific type, keyed by model name."""
        return {
            name: info for name, info in self.model_registry.items()
            if info.model_type == model_type
        }

    def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
        """Get the highest-scoring model of a type, or None if there are none."""
        models_of_type = self.get_models_by_type(model_type)

        if not models_of_type:
            return None

        return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())

    def load_best_models(self) -> Dict[str, Any]:
        """Load the best model of each type in ``MODEL_TYPES`` onto the CPU.

        Returns:
            Dict mapping model type to ``{'model', 'info', 'path'}``. Types with
            no model or an unreadable file are omitted (and logged).
        """
        loaded_models = {}

        for model_type in self.MODEL_TYPES:
            best_model = self.get_best_model(model_type)

            if best_model is None:
                logger.info(f"No {model_type} model available")
                continue

            try:
                model_path = Path(best_model.file_path)
                if not model_path.exists():
                    logger.warning(f"Best {model_type} model file not found: {model_path}")
                    continue

                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints from trusted locations.
                model_data = torch.load(model_path, map_location='cpu')
                loaded_models[model_type] = {
                    'model': model_data,
                    'info': best_model,
                    'path': str(model_path)
                }
                logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
                            f"(Score: {best_model.metrics.get_composite_score():.3f})")
            except Exception as e:
                logger.error(f"Error loading {model_type} model: {e}")

        return loaded_models

    def update_model_performance(self, model_name: str, metrics: ModelMetrics):
        """Replace a registered model's metrics, bump ``last_updated`` and save."""
        if model_name in self.model_registry:
            self.model_registry[model_name].metrics = metrics
            self.model_registry[model_name].last_updated = datetime.now()
            self._save_registry()

            logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
        else:
            logger.warning(f"Model {model_name} not found in registry")

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage usage statistics.

        ``registered_size_mb`` sums the sizes recorded at registration.
        ``actual_size_mb`` measures the files currently under ``best_models_dir``.
        Utilization is reported as 0 when the storage limit is 0.
        """
        total_size_mb = sum(info.file_size_mb for info in self.model_registry.values())
        model_count = len(self.model_registry)

        # Check actual storage usage
        actual_size_mb = 0
        if self.best_models_dir.exists():
            actual_size_mb = sum(
                f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
            ) / 1024 / 1024

        limit_gb = self.config['total_storage_limit_gb']
        utilization = (actual_size_mb / 1024) / limit_gb * 100 if limit_gb else 0.0

        return {
            'total_models': model_count,
            'registered_size_mb': total_size_mb,
            'actual_size_mb': actual_size_mb,
            'storage_limit_gb': limit_gb,
            'utilization_percent': utilization,
            'models_by_type': {
                model_type: len(self.get_models_by_type(model_type))
                for model_type in self.MODEL_TYPES
            }
        }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get all registered models as summary dicts, highest score first."""
        now = datetime.now()
        leaderboard = [
            {
                'name': model_name,
                'type': model_info.model_type,
                'score': model_info.metrics.get_composite_score(),
                'profit_factor': model_info.metrics.profit_factor,
                'win_rate': model_info.metrics.win_rate,
                'sharpe_ratio': model_info.metrics.sharpe_ratio,
                'size_mb': model_info.file_size_mb,
                'age_days': (now - model_info.creation_time).days,
                'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
            }
            for model_name, model_info in self.model_registry.items()
        ]

        # Sort by score
        leaderboard.sort(key=lambda x: x['score'], reverse=True)
        return leaderboard

    def _iter_checkpoint_candidates(self, patterns: Tuple[str, ...]):
        """Yield files under ``base_dir`` whose name matches a checkpoint pattern.

        The walk does not follow directory symlinks and never enters
        ``_PROTECTED_DIR_NAMES``. The following are skipped:
        - names with a suffix in ``_NON_MODEL_SUFFIXES``;
        - any path containing ``best_models`` (relative to ``base_dir``);
        - files that are registered in the model registry.
        """
        registered = {Path(info.file_path).resolve() for info in self.model_registry.values()}

        for root, dirnames, filenames in os.walk(self.base_dir):
            # Prune in place so os.walk does not descend into protected trees.
            dirnames[:] = [d for d in dirnames if d not in self._PROTECTED_DIR_NAMES]

            for filename in filenames:
                name_path = Path(filename)
                if name_path.suffix.lower() in self._NON_MODEL_SUFFIXES:
                    continue
                if not any(name_path.match(pattern) for pattern in patterns):
                    continue

                file_path = Path(root) / filename
                if "best_models" in os.path.relpath(file_path, self.base_dir):
                    continue
                if not file_path.is_file() or file_path.resolve() in registered:
                    continue
                yield file_path

    def cleanup_checkpoints(self) -> Dict[str, Any]:
        """Delete checkpoint files older than ``max_checkpoint_age_days``.

        Candidates come from ``_iter_checkpoint_candidates``, so protected
        directories, source/config files and registered models are never
        touched. Per-file errors are logged and collected, not raised.

        Returns:
            Dict with ``deleted_files``, ``freed_space_mb`` and ``errors``.
        """
        cleanup_summary = {
            'deleted_files': 0,
            'freed_space_mb': 0.0,
            'errors': []
        }

        cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])

        # File-name patterns for checkpoint files (matched at any depth)
        checkpoint_patterns = (
            "checkpoint_*.pt",
            "model_*.pt",
            "*checkpoint*",
            "epoch_*.pt",
        )

        for file_path in self._iter_checkpoint_candidates(checkpoint_patterns):
            try:
                stat_result = file_path.stat()
                if datetime.fromtimestamp(stat_result.st_mtime) >= cutoff_date:
                    continue  # still inside the retention window
                file_path.unlink()
            except (OSError, OverflowError, ValueError) as e:
                error_msg = f"Error deleting checkpoint {file_path}: {e}"
                logger.error(error_msg)
                cleanup_summary['errors'].append(error_msg)
                continue

            cleanup_summary['deleted_files'] += 1
            cleanup_summary['freed_space_mb'] += stat_result.st_size / 1024 / 1024

        if cleanup_summary['deleted_files'] > 0:
            logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
                        f"freed {cleanup_summary['freed_space_mb']:.1f}MB")

        return cleanup_summary


def create_model_manager(base_dir: str = ".", config: Optional[Dict[str, Any]] = None) -> ModelManager:
    """Create and return a new ModelManager.

    Each call builds a fresh instance; nothing is cached at module level.
    Called with no arguments it is equivalent to ``ModelManager()``.

    Args:
        base_dir: Project root passed through to ``ModelManager``.
        config: Optional configuration overrides passed through to ``ModelManager``.
    """
    return ModelManager(base_dir=base_dir, config=config)


# Example usage
def main() -> None:
    """Interactively delete all existing model files after explicit confirmation.

    A dry-run preview (``confirm=False``) is logged first, so the operator can
    see how much would be deleted before typing ``CONFIRM``. Closed stdin or
    Ctrl-C at the prompt is treated as a refusal.
    """
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()

    # Dry run first: nothing is deleted, counts are logged.
    manager.cleanup_all_existing_models(confirm=False)

    print("WARNING: This will delete ALL existing models!")
    print("Type 'CONFIRM' to proceed:")
    try:
        user_input = input().strip()
    except (EOFError, KeyboardInterrupt):
        user_input = ""

    if user_input != "CONFIRM":
        print("Cleanup cancelled")
        return

    cleanup_result = manager.cleanup_all_existing_models(confirm=True)
    print("\nCleanup complete:")
    print(f"- Deleted {cleanup_result['files_deleted']} files")
    print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
    print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")

    if cleanup_result['errors']:
        print(f"- {len(cleanup_result['errors'])} errors occurred")
        for error in cleanup_result['errors']:
            print(f"  * {error}")


if __name__ == "__main__":
    main()