added leverage slider
This commit is contained in:
558
model_manager.py
Normal file
558
model_manager.py
Normal file
@ -0,0 +1,558 @@
|
||||
"""
|
||||
Enhanced Model Management System for Trading Dashboard
|
||||
|
||||
This system provides:
|
||||
- Automatic cleanup of old model checkpoints
|
||||
- Best model tracking with performance metrics
|
||||
- Configurable retention policies
|
||||
- Startup model loading
|
||||
- Performance-based model selection
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import logging
|
||||
import torch
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ModelMetrics:
|
||||
"""Performance metrics for model evaluation"""
|
||||
accuracy: float = 0.0
|
||||
profit_factor: float = 0.0
|
||||
win_rate: float = 0.0
|
||||
sharpe_ratio: float = 0.0
|
||||
max_drawdown: float = 0.0
|
||||
total_trades: int = 0
|
||||
avg_trade_duration: float = 0.0
|
||||
confidence_score: float = 0.0
|
||||
|
||||
def get_composite_score(self) -> float:
|
||||
"""Calculate composite performance score"""
|
||||
# Weighted composite score
|
||||
weights = {
|
||||
'profit_factor': 0.3,
|
||||
'sharpe_ratio': 0.25,
|
||||
'win_rate': 0.2,
|
||||
'accuracy': 0.15,
|
||||
'confidence_score': 0.1
|
||||
}
|
||||
|
||||
# Normalize values to 0-1 range
|
||||
normalized_pf = min(max(self.profit_factor / 3.0, 0), 1) # PF of 3+ = 1.0
|
||||
normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1) # Sharpe -2 to 2 -> 0 to 1
|
||||
normalized_win_rate = self.win_rate
|
||||
normalized_accuracy = self.accuracy
|
||||
normalized_confidence = self.confidence_score
|
||||
|
||||
# Apply penalties for poor performance
|
||||
drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2) # Penalty for >20% drawdown
|
||||
|
||||
score = (
|
||||
weights['profit_factor'] * normalized_pf +
|
||||
weights['sharpe_ratio'] * normalized_sharpe +
|
||||
weights['win_rate'] * normalized_win_rate +
|
||||
weights['accuracy'] * normalized_accuracy +
|
||||
weights['confidence_score'] * normalized_confidence
|
||||
) * drawdown_penalty
|
||||
|
||||
return min(max(score, 0), 1)
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Complete model information and metadata"""
|
||||
model_type: str # 'cnn', 'rl', 'transformer'
|
||||
model_name: str
|
||||
file_path: str
|
||||
creation_time: datetime
|
||||
last_updated: datetime
|
||||
file_size_mb: float
|
||||
metrics: ModelMetrics
|
||||
training_episodes: int = 0
|
||||
model_version: str = "1.0"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
data = asdict(self)
|
||||
data['creation_time'] = self.creation_time.isoformat()
|
||||
data['last_updated'] = self.last_updated.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
|
||||
"""Create from dictionary"""
|
||||
data['creation_time'] = datetime.fromisoformat(data['creation_time'])
|
||||
data['last_updated'] = datetime.fromisoformat(data['last_updated'])
|
||||
data['metrics'] = ModelMetrics(**data['metrics'])
|
||||
return cls(**data)
|
||||
|
||||
class ModelManager:
|
||||
"""Enhanced model management system"""
|
||||
|
||||
def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
|
||||
self.base_dir = Path(base_dir)
|
||||
self.config = config or self._get_default_config()
|
||||
|
||||
# Model directories
|
||||
self.models_dir = self.base_dir / "models"
|
||||
self.nn_models_dir = self.base_dir / "NN" / "models"
|
||||
self.registry_file = self.models_dir / "model_registry.json"
|
||||
self.best_models_dir = self.models_dir / "best_models"
|
||||
|
||||
# Create directories
|
||||
self.best_models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Model registry
|
||||
self.model_registry: Dict[str, ModelInfo] = {}
|
||||
self._load_registry()
|
||||
|
||||
logger.info(f"Model Manager initialized - Base: {self.base_dir}")
|
||||
logger.info(f"Retention policy: Keep {self.config['max_models_per_type']} best models per type")
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration"""
|
||||
return {
|
||||
'max_models_per_type': 3, # Keep top 3 models per type
|
||||
'max_total_models': 10, # Maximum total models to keep
|
||||
'cleanup_frequency_hours': 24, # Cleanup every 24 hours
|
||||
'min_performance_threshold': 0.3, # Minimum composite score
|
||||
'max_checkpoint_age_days': 7, # Delete checkpoints older than 7 days
|
||||
'auto_cleanup_enabled': True,
|
||||
'backup_before_cleanup': True,
|
||||
'model_size_limit_mb': 100, # Individual model size limit
|
||||
'total_storage_limit_gb': 5.0 # Total storage limit
|
||||
}
|
||||
|
||||
def _load_registry(self):
|
||||
"""Load model registry from file"""
|
||||
try:
|
||||
if self.registry_file.exists():
|
||||
with open(self.registry_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
self.model_registry = {
|
||||
k: ModelInfo.from_dict(v) for k, v in data.items()
|
||||
}
|
||||
logger.info(f"Loaded {len(self.model_registry)} models from registry")
|
||||
else:
|
||||
logger.info("No existing model registry found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model registry: {e}")
|
||||
self.model_registry = {}
|
||||
|
||||
def _save_registry(self):
|
||||
"""Save model registry to file"""
|
||||
try:
|
||||
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.registry_file, 'w') as f:
|
||||
data = {k: v.to_dict() for k, v in self.model_registry.items()}
|
||||
json.dump(data, f, indent=2, default=str)
|
||||
logger.info(f"Saved registry with {len(self.model_registry)} models")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model registry: {e}")
|
||||
|
||||
def cleanup_all_existing_models(self, confirm: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Clean up all existing model files and prepare for 2-action system training
|
||||
|
||||
Args:
|
||||
confirm: If True, perform the cleanup. If False, return what would be cleaned
|
||||
|
||||
Returns:
|
||||
Dict with cleanup statistics
|
||||
"""
|
||||
cleanup_stats = {
|
||||
'files_found': 0,
|
||||
'files_deleted': 0,
|
||||
'directories_cleaned': 0,
|
||||
'space_freed_mb': 0.0,
|
||||
'errors': []
|
||||
}
|
||||
|
||||
# Model file patterns for both 2-action and legacy 3-action systems
|
||||
model_patterns = [
|
||||
"**/*.pt", "**/*.pth", "**/*.h5", "**/*.pkl", "**/*.joblib", "**/*.model",
|
||||
"**/checkpoint_*", "**/model_*", "**/cnn_*", "**/dqn_*", "**/rl_*"
|
||||
]
|
||||
|
||||
# Directories to clean
|
||||
model_directories = [
|
||||
"models/saved",
|
||||
"NN/models/saved",
|
||||
"NN/models/saved/checkpoints",
|
||||
"NN/models/saved/realtime_checkpoints",
|
||||
"NN/models/saved/realtime_ticks_checkpoints",
|
||||
"model_backups"
|
||||
]
|
||||
|
||||
try:
|
||||
# Scan for files to be cleaned
|
||||
for directory in model_directories:
|
||||
dir_path = Path(self.base_dir) / directory
|
||||
if dir_path.exists():
|
||||
for pattern in model_patterns:
|
||||
for file_path in dir_path.glob(pattern):
|
||||
if file_path.is_file():
|
||||
cleanup_stats['files_found'] += 1
|
||||
file_size = file_path.stat().st_size / (1024 * 1024) # MB
|
||||
cleanup_stats['space_freed_mb'] += file_size
|
||||
|
||||
if confirm:
|
||||
try:
|
||||
file_path.unlink()
|
||||
cleanup_stats['files_deleted'] += 1
|
||||
logger.info(f"Deleted model file: {file_path}")
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Failed to delete {file_path}: {e}")
|
||||
|
||||
# Clean up empty checkpoint directories
|
||||
for directory in model_directories:
|
||||
dir_path = Path(self.base_dir) / directory
|
||||
if dir_path.exists():
|
||||
for subdir in dir_path.rglob("*"):
|
||||
if subdir.is_dir() and not any(subdir.iterdir()):
|
||||
if confirm:
|
||||
try:
|
||||
subdir.rmdir()
|
||||
cleanup_stats['directories_cleaned'] += 1
|
||||
logger.info(f"Removed empty directory: {subdir}")
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Failed to remove directory {subdir}: {e}")
|
||||
|
||||
if confirm:
|
||||
# Clear the registry for fresh start with 2-action system
|
||||
self.model_registry = {
|
||||
'models': {},
|
||||
'metadata': {
|
||||
'last_updated': datetime.now().isoformat(),
|
||||
'total_models': 0,
|
||||
'system_type': '2_action', # Mark as 2-action system
|
||||
'action_space': ['SELL', 'BUY'],
|
||||
'version': '2.0'
|
||||
}
|
||||
}
|
||||
self._save_registry()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("MODEL CLEANUP COMPLETED - 2-ACTION SYSTEM READY")
|
||||
logger.info(f"Files deleted: {cleanup_stats['files_deleted']}")
|
||||
logger.info(f"Space freed: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||
logger.info(f"Directories cleaned: {cleanup_stats['directories_cleaned']}")
|
||||
logger.info("Registry reset for 2-action system (BUY/SELL)")
|
||||
logger.info("Ready for fresh training with intelligent position management")
|
||||
logger.info("=" * 60)
|
||||
else:
|
||||
logger.info("=" * 60)
|
||||
logger.info("MODEL CLEANUP PREVIEW - 2-ACTION SYSTEM MIGRATION")
|
||||
logger.info(f"Files to delete: {cleanup_stats['files_found']}")
|
||||
logger.info(f"Space to free: {cleanup_stats['space_freed_mb']:.2f} MB")
|
||||
logger.info("Run with confirm=True to perform cleanup")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
cleanup_stats['errors'].append(f"Cleanup error: {e}")
|
||||
logger.error(f"Error during model cleanup: {e}")
|
||||
|
||||
return cleanup_stats
|
||||
|
||||
def register_model(self, model_path: str, model_type: str, metrics: Optional[ModelMetrics] = None) -> str:
|
||||
"""
|
||||
Register a new model in the 2-action system
|
||||
|
||||
Args:
|
||||
model_path: Path to the model file
|
||||
model_type: Type of model ('cnn', 'rl', 'transformer')
|
||||
metrics: Performance metrics
|
||||
|
||||
Returns:
|
||||
str: Unique model name/ID
|
||||
"""
|
||||
if not Path(model_path).exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Generate unique model name
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_name = f"{model_type}_2action_{timestamp}"
|
||||
|
||||
# Get file info
|
||||
file_path = Path(model_path)
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Default metrics for 2-action system
|
||||
if metrics is None:
|
||||
metrics = ModelMetrics(
|
||||
accuracy=0.0,
|
||||
profit_factor=1.0,
|
||||
win_rate=0.5,
|
||||
sharpe_ratio=0.0,
|
||||
max_drawdown=0.0,
|
||||
confidence_score=0.5
|
||||
)
|
||||
|
||||
# Create model info
|
||||
model_info = ModelInfo(
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
file_path=str(file_path.absolute()),
|
||||
creation_time=datetime.now(),
|
||||
last_updated=datetime.now(),
|
||||
file_size_mb=file_size_mb,
|
||||
metrics=metrics,
|
||||
model_version="2.0" # 2-action system version
|
||||
)
|
||||
|
||||
# Add to registry
|
||||
self.model_registry['models'][model_name] = model_info.to_dict()
|
||||
self.model_registry['metadata']['total_models'] = len(self.model_registry['models'])
|
||||
self.model_registry['metadata']['last_updated'] = datetime.now().isoformat()
|
||||
self.model_registry['metadata']['system_type'] = '2_action'
|
||||
self.model_registry['metadata']['action_space'] = ['SELL', 'BUY']
|
||||
|
||||
self._save_registry()
|
||||
|
||||
# Cleanup old models if necessary
|
||||
self._cleanup_models_by_type(model_type)
|
||||
|
||||
logger.info(f"Registered 2-action model: {model_name}")
|
||||
logger.info(f"Model type: {model_type}, Size: {file_size_mb:.2f} MB")
|
||||
logger.info(f"Performance score: {metrics.get_composite_score():.4f}")
|
||||
|
||||
return model_name
|
||||
|
||||
def _should_keep_model(self, model_info: ModelInfo) -> bool:
|
||||
"""Determine if model should be kept based on performance"""
|
||||
score = model_info.metrics.get_composite_score()
|
||||
|
||||
# Check minimum threshold
|
||||
if score < self.config['min_performance_threshold']:
|
||||
return False
|
||||
|
||||
# Check size limit
|
||||
if model_info.file_size_mb > self.config['model_size_limit_mb']:
|
||||
logger.warning(f"Model too large: {model_info.file_size_mb:.1f}MB > {self.config['model_size_limit_mb']}MB")
|
||||
return False
|
||||
|
||||
# Check if better than existing models of same type
|
||||
existing_models = self.get_models_by_type(model_info.model_type)
|
||||
if len(existing_models) >= self.config['max_models_per_type']:
|
||||
# Find worst performing model
|
||||
worst_model = min(existing_models.values(), key=lambda m: m.metrics.get_composite_score())
|
||||
if score <= worst_model.metrics.get_composite_score():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _cleanup_models_by_type(self, model_type: str):
|
||||
"""Cleanup old models of specific type, keeping only the best ones"""
|
||||
models_of_type = self.get_models_by_type(model_type)
|
||||
max_keep = self.config['max_models_per_type']
|
||||
|
||||
if len(models_of_type) <= max_keep:
|
||||
return
|
||||
|
||||
# Sort by performance score
|
||||
sorted_models = sorted(
|
||||
models_of_type.items(),
|
||||
key=lambda x: x[1].metrics.get_composite_score(),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Keep only the best models
|
||||
models_to_keep = sorted_models[:max_keep]
|
||||
models_to_remove = sorted_models[max_keep:]
|
||||
|
||||
for model_name, model_info in models_to_remove:
|
||||
try:
|
||||
# Remove file
|
||||
model_path = Path(model_info.file_path)
|
||||
if model_path.exists():
|
||||
model_path.unlink()
|
||||
|
||||
# Remove from registry
|
||||
del self.model_registry[model_name]
|
||||
|
||||
logger.info(f"Removed old model: {model_name} (Score: {model_info.metrics.get_composite_score():.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing model {model_name}: {e}")
|
||||
|
||||
def get_models_by_type(self, model_type: str) -> Dict[str, ModelInfo]:
|
||||
"""Get all models of a specific type"""
|
||||
return {
|
||||
name: info for name, info in self.model_registry.items()
|
||||
if info.model_type == model_type
|
||||
}
|
||||
|
||||
def get_best_model(self, model_type: str) -> Optional[ModelInfo]:
|
||||
"""Get the best performing model of a specific type"""
|
||||
models_of_type = self.get_models_by_type(model_type)
|
||||
|
||||
if not models_of_type:
|
||||
return None
|
||||
|
||||
return max(models_of_type.values(), key=lambda m: m.metrics.get_composite_score())
|
||||
|
||||
def load_best_models(self) -> Dict[str, Any]:
|
||||
"""Load the best models for each type"""
|
||||
loaded_models = {}
|
||||
|
||||
for model_type in ['cnn', 'rl', 'transformer']:
|
||||
best_model = self.get_best_model(model_type)
|
||||
|
||||
if best_model:
|
||||
try:
|
||||
model_path = Path(best_model.file_path)
|
||||
if model_path.exists():
|
||||
# Load the model
|
||||
model_data = torch.load(model_path, map_location='cpu')
|
||||
loaded_models[model_type] = {
|
||||
'model': model_data,
|
||||
'info': best_model,
|
||||
'path': str(model_path)
|
||||
}
|
||||
logger.info(f"Loaded best {model_type} model: {best_model.model_name} "
|
||||
f"(Score: {best_model.metrics.get_composite_score():.3f})")
|
||||
else:
|
||||
logger.warning(f"Best {model_type} model file not found: {model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {model_type} model: {e}")
|
||||
else:
|
||||
logger.info(f"No {model_type} model available")
|
||||
|
||||
return loaded_models
|
||||
|
||||
def update_model_performance(self, model_name: str, metrics: ModelMetrics):
|
||||
"""Update performance metrics for a model"""
|
||||
if model_name in self.model_registry:
|
||||
self.model_registry[model_name].metrics = metrics
|
||||
self.model_registry[model_name].last_updated = datetime.now()
|
||||
self._save_registry()
|
||||
|
||||
logger.info(f"Updated metrics for {model_name}: Score {metrics.get_composite_score():.3f}")
|
||||
else:
|
||||
logger.warning(f"Model {model_name} not found in registry")
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage usage statistics"""
|
||||
total_size_mb = 0
|
||||
model_count = 0
|
||||
|
||||
for model_info in self.model_registry.values():
|
||||
total_size_mb += model_info.file_size_mb
|
||||
model_count += 1
|
||||
|
||||
# Check actual storage usage
|
||||
actual_size_mb = 0
|
||||
if self.best_models_dir.exists():
|
||||
actual_size_mb = sum(
|
||||
f.stat().st_size for f in self.best_models_dir.rglob('*') if f.is_file()
|
||||
) / 1024 / 1024
|
||||
|
||||
return {
|
||||
'total_models': model_count,
|
||||
'registered_size_mb': total_size_mb,
|
||||
'actual_size_mb': actual_size_mb,
|
||||
'storage_limit_gb': self.config['total_storage_limit_gb'],
|
||||
'utilization_percent': (actual_size_mb / 1024) / self.config['total_storage_limit_gb'] * 100,
|
||||
'models_by_type': {
|
||||
model_type: len(self.get_models_by_type(model_type))
|
||||
for model_type in ['cnn', 'rl', 'transformer']
|
||||
}
|
||||
}
|
||||
|
||||
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
|
||||
"""Get model performance leaderboard"""
|
||||
leaderboard = []
|
||||
|
||||
for model_name, model_info in self.model_registry.items():
|
||||
leaderboard.append({
|
||||
'name': model_name,
|
||||
'type': model_info.model_type,
|
||||
'score': model_info.metrics.get_composite_score(),
|
||||
'profit_factor': model_info.metrics.profit_factor,
|
||||
'win_rate': model_info.metrics.win_rate,
|
||||
'sharpe_ratio': model_info.metrics.sharpe_ratio,
|
||||
'size_mb': model_info.file_size_mb,
|
||||
'age_days': (datetime.now() - model_info.creation_time).days,
|
||||
'last_updated': model_info.last_updated.strftime('%Y-%m-%d %H:%M')
|
||||
})
|
||||
|
||||
# Sort by score
|
||||
leaderboard.sort(key=lambda x: x['score'], reverse=True)
|
||||
|
||||
return leaderboard
|
||||
|
||||
def cleanup_checkpoints(self) -> Dict[str, Any]:
|
||||
"""Clean up old checkpoint files"""
|
||||
cleanup_summary = {
|
||||
'deleted_files': 0,
|
||||
'freed_space_mb': 0,
|
||||
'errors': []
|
||||
}
|
||||
|
||||
cutoff_date = datetime.now() - timedelta(days=self.config['max_checkpoint_age_days'])
|
||||
|
||||
# Search for checkpoint files
|
||||
checkpoint_patterns = [
|
||||
"**/checkpoint_*.pt",
|
||||
"**/model_*.pt",
|
||||
"**/*checkpoint*",
|
||||
"**/epoch_*.pt"
|
||||
]
|
||||
|
||||
for pattern in checkpoint_patterns:
|
||||
for file_path in self.base_dir.rglob(pattern):
|
||||
if "best_models" not in str(file_path) and file_path.is_file():
|
||||
try:
|
||||
file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
if file_time < cutoff_date:
|
||||
size_mb = file_path.stat().st_size / 1024 / 1024
|
||||
file_path.unlink()
|
||||
cleanup_summary['deleted_files'] += 1
|
||||
cleanup_summary['freed_space_mb'] += size_mb
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting checkpoint {file_path}: {e}"
|
||||
logger.error(error_msg)
|
||||
cleanup_summary['errors'].append(error_msg)
|
||||
|
||||
if cleanup_summary['deleted_files'] > 0:
|
||||
logger.info(f"Checkpoint cleanup: Deleted {cleanup_summary['deleted_files']} files, "
|
||||
f"freed {cleanup_summary['freed_space_mb']:.1f}MB")
|
||||
|
||||
return cleanup_summary
|
||||
|
||||
def create_model_manager() -> ModelManager:
|
||||
"""Create and initialize the global model manager"""
|
||||
return ModelManager()
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Create model manager
|
||||
manager = ModelManager()
|
||||
|
||||
# Clean up all existing models (with confirmation)
|
||||
print("WARNING: This will delete ALL existing models!")
|
||||
print("Type 'CONFIRM' to proceed:")
|
||||
user_input = input().strip()
|
||||
|
||||
if user_input == "CONFIRM":
|
||||
cleanup_result = manager.cleanup_all_existing_models(confirm=True)
|
||||
print(f"\nCleanup complete:")
|
||||
print(f"- Deleted {cleanup_result['files_deleted']} files")
|
||||
print(f"- Freed {cleanup_result['space_freed_mb']:.1f}MB of space")
|
||||
print(f"- Cleaned {cleanup_result['directories_cleaned']} directories")
|
||||
|
||||
if cleanup_result['errors']:
|
||||
print(f"- {len(cleanup_result['errors'])} errors occurred")
|
||||
else:
|
||||
print("Cleanup cancelled")
|
Reference in New Issue
Block a user