"""
|
|
Unified Model Management System for Trading Dashboard
|
|
|
|
CONSOLIDATED SYSTEM - All model management functionality in one place
|
|
|
|
This system provides:
|
|
- Automatic cleanup of old model checkpoints
|
|
- Best model tracking with performance metrics
|
|
- Configurable retention policies
|
|
- Startup model loading
|
|
- Performance-based model selection
|
|
- Robust model saving with multiple fallback strategies
|
|
- Checkpoint management with W&B integration
|
|
- Centralized storage using @checkpoints/ structure
|
|
"""

import os
import json
import shutil
import logging
import torch
import glob
import pickle
import hashlib
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List, Tuple, Union
from collections import defaultdict

# W&B import (optional)
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None

logger = logging.getLogger(__name__)


@dataclass
class ModelMetrics:
    """Enhanced performance metrics for model evaluation"""
    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0

    # Additional metrics from checkpoint_manager
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None

    def get_composite_score(self) -> float:
        """Calculate a weighted composite performance score in [0, 1]"""
        weights = {
            'profit_factor': 0.25,
            'sharpe_ratio': 0.2,
            'win_rate': 0.15,
            'accuracy': 0.15,
            'confidence_score': 0.1,
            'loss_penalty': 0.1,   # penalize high training loss
            'val_penalty': 0.05    # penalize high validation loss
        }

        # Normalize values to the 0-1 range
        normalized_pf = min(max(self.profit_factor / 3.0, 0), 1)         # PF of 3+ = 1.0
        normalized_sharpe = min(max((self.sharpe_ratio + 2) / 4, 0), 1)  # Sharpe -2..2 -> 0..1
        normalized_win_rate = self.win_rate
        normalized_accuracy = self.accuracy
        normalized_confidence = self.confidence_score

        # Loss multiplier: lower loss -> closer to 1.0, floored at 0.1
        loss_penalty = 1.0
        if self.loss is not None and self.loss > 0:
            loss_penalty = max(0.1, 1 / (1 + self.loss))

        # Validation-loss multiplier, same shape as above
        val_penalty = 1.0
        if self.val_loss is not None and self.val_loss > 0:
            val_penalty = max(0.1, 1 / (1 + self.val_loss))

        # Drawdown penalty: scales linearly to 0 at a 20% drawdown
        drawdown_penalty = max(0, 1 - self.max_drawdown / 0.2)

        score = (
            weights['profit_factor'] * normalized_pf +
            weights['sharpe_ratio'] * normalized_sharpe +
            weights['win_rate'] * normalized_win_rate +
            weights['accuracy'] * normalized_accuracy +
            weights['confidence_score'] * normalized_confidence +
            weights['loss_penalty'] * loss_penalty +
            weights['val_penalty'] * val_penalty
        ) * drawdown_penalty

        return min(max(score, 0), 1)
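
    # Worked example (illustrative values, not from real training): with
    # profit_factor=1.5, sharpe_ratio=1.0, win_rate=0.55, accuracy=0.60,
    # confidence_score=0.70 and no recorded losses, the composite score is
    # 0.25*0.5 + 0.2*0.75 + 0.15*0.55 + 0.15*0.60 + 0.1*0.70 + 0.1 + 0.05
    # ≈ 0.67. A max_drawdown of 0.30 drives the drawdown penalty to 0 and
    # zeroes the score regardless of the other metrics.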


@dataclass
class ModelInfo:
    """Model information tracking"""
    model_type: str  # 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization"""
        data = asdict(self)
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create from dictionary"""
        data['creation_time'] = datetime.fromisoformat(data['creation_time'])
        data['last_updated'] = datetime.fromisoformat(data['last_updated'])
        data['metrics'] = ModelMetrics(**data['metrics'])
        return cls(**data)


@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)


class ModelManager:
    """Unified model management system with @checkpoints/ structure"""

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        self.base_dir = Path(base_dir)
        self.config = config or self._get_default_config()

        # Updated directory structure using @checkpoints/
        self.checkpoints_dir = self.base_dir / "@checkpoints"
        self.models_dir = self.checkpoints_dir / "models"
        self.saved_dir = self.checkpoints_dir / "saved"
        self.best_models_dir = self.checkpoints_dir / "best_models"
        self.archive_dir = self.checkpoints_dir / "archive"

        # Model type directories within @checkpoints/
        self.model_dirs = {
            'cnn': self.checkpoints_dir / "cnn",
            'dqn': self.checkpoints_dir / "dqn",
            'rl': self.checkpoints_dir / "rl",
            'transformer': self.checkpoints_dir / "transformer",
            'hybrid': self.checkpoints_dir / "hybrid"
        }

        # Legacy directories for backward compatibility
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.legacy_models_dir = self.base_dir / "models"

        # Legacy checkpoint directories (where existing checkpoints are stored)
        self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
        self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"

        # Metadata and checkpoint management
        self.metadata_file = self.checkpoints_dir / "model_metadata.json"
        self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"

        # Initialize storage
        self._initialize_directories()
        self.metadata = self._load_metadata()
        self.checkpoint_metadata = self._load_checkpoint_metadata()

        logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            'max_checkpoints_per_model': 5,
            'cleanup_old_models': True,
            'auto_archive': True,
            'wandb_enabled': WANDB_AVAILABLE,
            'checkpoint_retention_days': 30
        }

    def _initialize_directories(self):
        """Initialize directory structure"""
        directories = [
            self.checkpoints_dir,
            self.models_dir,
            self.saved_dir,
            self.best_models_dir,
            self.archive_dir
        ] + list(self.model_dirs.values())

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load model metadata with legacy support"""
        metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}

        # First try to load from the new unified metadata
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    metadata = json.load(f)
                metadata.setdefault('models', {})
                logger.info(f"Loaded unified metadata from {self.metadata_file}")
            except Exception as e:
                logger.error(f"Error loading unified metadata: {e}")

        # Also load legacy metadata for backward compatibility
        if self.legacy_registry_file.exists():
            try:
                with open(self.legacy_registry_file, 'r') as f:
                    legacy_data = json.load(f)

                # Merge legacy data into unified metadata
                for model_name, model_info in legacy_data.get('models', {}).items():
                    if model_name in metadata['models']:
                        continue

                    # Convert legacy path format to an absolute path
                    legacy_path = model_info.get('latest_path', '')

                    # Handle different legacy path formats
                    if legacy_path and not Path(legacy_path).is_absolute():
                        # Try multiple path resolution strategies
                        possible_paths = [
                            self.legacy_checkpoints_dir / legacy_path,         # NN/models/checkpoints/models/cnn/...
                            self.legacy_checkpoints_dir.parent / legacy_path,  # NN/models/models/cnn/...
                            self.base_dir / legacy_path,                       # /project/models/cnn/...
                        ]

                        resolved_path = None
                        for path in possible_paths:
                            if path.exists():
                                resolved_path = path
                                break

                        if resolved_path:
                            legacy_path = str(resolved_path)
                        else:
                            # No direct match; try to find the file by name
                            filename = Path(legacy_path).name
                            for file_path in self.legacy_checkpoints_dir.rglob(filename):
                                legacy_path = str(file_path)
                                break

                    metadata['models'][model_name] = {
                        'type': model_info.get('type', 'unknown'),
                        'latest_path': legacy_path,
                        'last_saved': model_info.get('last_saved', 'legacy'),
                        'save_count': model_info.get('save_count', 1),
                        'checkpoints': model_info.get('checkpoints', [])
                    }
                    logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")

                logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")

            except Exception as e:
                logger.error(f"Error loading legacy metadata: {e}")

        return metadata
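
    # Example of the path resolution above (hypothetical registry entry): a
    # legacy value of "saved/cnn_model_best.pt" is tried in order as
    #   NN/models/checkpoints/saved/cnn_model_best.pt,
    #   NN/models/saved/cnn_model_best.pt, and
    #   <base_dir>/saved/cnn_model_best.pt,
    # before falling back to a recursive filename search under
    # NN/models/checkpoints/.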

    def _load_checkpoint_metadata(self) -> Dict[str, List[CheckpointMetadata]]:
        """Load checkpoint metadata"""
        if self.checkpoint_metadata_file.exists():
            try:
                with open(self.checkpoint_metadata_file, 'r') as f:
                    data = json.load(f)
                # Convert dict values back to CheckpointMetadata objects.
                # Wrap in a defaultdict so save_checkpoint() can append to
                # models that have no checkpoints yet.
                result = defaultdict(list)
                for key, checkpoints in data.items():
                    result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
                return result
            except Exception as e:
                logger.error(f"Error loading checkpoint metadata: {e}")
        return defaultdict(list)

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a model checkpoint with enhanced error handling and validation"""
        try:
            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Create checkpoint directory
            checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Generate checkpoint filename
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"
            filename = f"{checkpoint_id}.pt"
            filepath = checkpoint_dir / filename

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'checkpoint_id': checkpoint_id,
                'model_name': model_name,
                'model_type': model_type,
                'performance_score': performance_score,
                'performance_metrics': performance_metrics,
                'training_metadata': training_metadata or {},
                'created_at': datetime.now().isoformat(),
                'version': '2.0'
            }

            torch.save(save_dict, filepath)

            # Create checkpoint metadata
            file_size_mb = filepath.stat().st_size / (1024 * 1024)
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(filepath),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=performance_metrics.get('epoch'),
                training_time_hours=performance_metrics.get('training_time_hours'),
                total_parameters=performance_metrics.get('total_parameters')
            )

            # Store metadata
            self.checkpoint_metadata[model_name].append(metadata)
            self._save_checkpoint_metadata()

            # Rotate checkpoints if needed
            self._rotate_checkpoints(model_name)

            # Upload to W&B if enabled
            if self.config.get('wandb_enabled'):
                self._upload_to_wandb(metadata)

            logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate performance score from metrics"""
        # Simple weighted score - can be enhanced
        weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
        score = 0.0
        for metric, weight in weights.items():
            if metric in metrics:
                score += metrics[metric] * weight
        return score
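
    # Illustrative values (assumed, not from real training): accuracy=0.60,
    # profit_factor=1.5, win_rate=0.55, sharpe_ratio=1.0 gives
    # 0.4*0.60 + 0.3*1.5 + 0.2*0.55 + 0.1*1.0 = 0.90. Unlike the composite
    # score in ModelMetrics, this score is not normalized, so an unbounded
    # profit_factor can push it above 1.0.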

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Determine whether a checkpoint should be saved"""
        existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
        if not existing_checkpoints:
            return True

        # Keep if we have fewer than the max, or if better than the worst checkpoint
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if len(existing_checkpoints) < max_checkpoints:
            return True

        worst_score = min(cp.performance_score for cp in existing_checkpoints)
        return performance_score > worst_score

    def _rotate_checkpoints(self, model_name: str):
        """Rotate checkpoints to maintain the configured maximum count"""
        checkpoints = self.checkpoint_metadata.get(model_name, [])
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)

        if len(checkpoints) <= max_checkpoints:
            return

        # Sort by performance score (descending)
        checkpoints.sort(key=lambda x: x.performance_score, reverse=True)

        # Remove excess checkpoints
        to_remove = checkpoints[max_checkpoints:]
        for checkpoint in to_remove:
            try:
                Path(checkpoint.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")

        # Update metadata
        self.checkpoint_metadata[model_name] = checkpoints[:max_checkpoints]
        self._save_checkpoint_metadata()

    def _save_checkpoint_metadata(self):
        """Save checkpoint metadata to file"""
        try:
            data = {}
            for model_name, checkpoints in self.checkpoint_metadata.items():
                data[model_name] = [cp.to_dict() for cp in checkpoints]

            with open(self.checkpoint_metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def _upload_to_wandb(self, metadata: CheckpointMetadata) -> Optional[str]:
        """Upload checkpoint to W&B"""
        if not WANDB_AVAILABLE:
            return None

        try:
            # This would be implemented based on your W&B workflow
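            # A minimal sketch of what the upload could look like, assuming
            # an active W&B run (standard wandb Artifact API):
            #   artifact = wandb.Artifact(metadata.checkpoint_id, type='model')
            #   artifact.add_file(metadata.file_path)
            #   wandb.log_artifact(artifact)
            #   return artifact.name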
            logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
            return None
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Load the best checkpoint for a model, with legacy support"""
        try:
            # First, try the unified registry
            model_info = self.metadata['models'].get(model_name)
            if model_info and model_info.get('latest_path') and Path(model_info['latest_path']).exists():
                logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
                # The registry may hold a non-ISO 'last_saved' value (e.g. 'legacy')
                try:
                    created_at = datetime.fromisoformat(model_info.get('last_saved', ''))
                except (TypeError, ValueError):
                    created_at = datetime.now()
                # Create metadata from model info for compatibility
                registry_metadata = CheckpointMetadata(
                    checkpoint_id=f"{model_name}_registry",
                    model_name=model_name,
                    model_type=model_info.get('type', model_name),
                    file_path=model_info['latest_path'],
                    created_at=created_at,
                    file_size_mb=0.0,  # Will be calculated if needed
                    performance_score=0.0,  # Unknown from registry
                    accuracy=None,
                    loss=None,  # Orchestrator will handle this
                    val_accuracy=None,
                    val_loss=None
                )
                return model_info['latest_path'], registry_metadata

            # Fallback to checkpoint metadata
            checkpoints = self.checkpoint_metadata.get(model_name, [])
            if checkpoints:
                # Get the best checkpoint by performance score
                best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)

                if Path(best_checkpoint.file_path).exists():
                    logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
                    return best_checkpoint.file_path, best_checkpoint

            # Legacy fallback: look for checkpoints in legacy directories
            logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
            legacy_path = self._find_legacy_checkpoint(model_name)
            if legacy_path:
                logger.info(f"Found legacy checkpoint: {legacy_path}")
                # Create a basic CheckpointMetadata for the legacy checkpoint
                legacy_metadata = CheckpointMetadata(
                    checkpoint_id=f"legacy_{model_name}",
                    model_name=model_name,
                    model_type=model_name,  # Will be inferred from model type
                    file_path=str(legacy_path),
                    created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
                    file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
                    performance_score=0.0,  # Unknown for legacy
                    accuracy=None,
                    loss=None
                )
                return str(legacy_path), legacy_metadata

            logger.warning(f"No checkpoints found for {model_name} in any location")
            return None

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
        """Find a checkpoint in the legacy directories"""
        if not self.legacy_checkpoints_dir.exists():
            return None

        # The project uses consistent short model names (dqn, cnn, cob_rl,
        # transformer, decision), so the model name itself is the primary
        # search pattern, with a few aliases for older file names.
        patterns = [model_name]
        if model_name == 'dqn':
            patterns.extend(['dqn_agent', 'agent'])
        elif model_name == 'cnn':
            patterns.extend(['cnn_model', 'enhanced_cnn'])
        elif model_name == 'cob_rl':
            patterns.extend(['rl', 'rl_agent', 'trading_agent'])

        def _search(directory: Path) -> Optional[Path]:
            for file_path in directory.rglob("*.pt"):
                if any(pattern in file_path.name.lower() for pattern in patterns):
                    return file_path
            return None

        # Search the most likely locations first: the legacy saved directory,
        # then model-specific saved directories, then archive and backtest
        # (which may contain RL or other models), and finally the entire
        # legacy tree as a last resort.
        search_dirs = [self.legacy_checkpoints_dir / "saved"]
        search_dirs += [self.legacy_checkpoints_dir / model_type / "saved"
                        for model_type in ['cnn', 'dqn', 'rl', 'transformer', 'decision']]
        search_dirs += [self.legacy_checkpoints_dir / "archive",
                        self.legacy_checkpoints_dir / "backtest",
                        self.legacy_checkpoints_dir]

        for directory in search_dirs:
            if directory.exists():
                found = _search(directory)
                if found:
                    return found

        return None

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics"""
        try:
            total_size = 0
            file_count = 0

            # models_dir and saved_dir live under checkpoints_dir, so a single
            # recursive walk covers everything without double counting
            if self.checkpoints_dir.exists():
                for file_path in self.checkpoints_dir.rglob('*'):
                    if file_path.is_file():
                        total_size += file_path.stat().st_size
                        file_count += 1

            return {
                'total_size_mb': total_size / (1024 * 1024),
                'file_count': file_count,
                'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0
            }
        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints (compatible with the old checkpoint_manager interface)"""
        try:
            stats = {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Count files in the new unified model-type directories
            # (cnn, dqn, rl, transformer, hybrid under @checkpoints/)
            for checkpoint_dir in self.model_dirs.values():
                if checkpoint_dir.exists():
                    model_files = list(checkpoint_dir.rglob('*.pt'))
                    if model_files:
                        model_name = checkpoint_dir.name
                        stats['total_models'] += 1

                        model_size = sum(f.stat().st_size for f in model_files)
                        stats['total_checkpoints'] += len(model_files)
                        stats['total_size_mb'] += model_size / (1024 * 1024)

                        # Use the most recently modified file as "latest"
                        latest_file = max(model_files, key=lambda f: f.stat().st_mtime)

                        stats['models'][model_name] = {
                            'checkpoint_count': len(model_files),
                            'total_size_mb': model_size / (1024 * 1024),
                            'best_performance': 0.0,  # Not tracked in unified system
                            'best_checkpoint_id': latest_file.name,
                            'latest_checkpoint': latest_file.name
                        }

            # Also check the saved models directory
            if self.saved_dir.exists():
                saved_files = list(self.saved_dir.rglob('*.pt'))
                if saved_files:
                    stats['total_checkpoints'] += len(saved_files)
                    saved_size = sum(f.stat().st_size for f in saved_files)
                    stats['total_size_mb'] += saved_size / (1024 * 1024)

            # Add legacy checkpoint statistics
            if self.legacy_checkpoints_dir.exists():
                legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
                if legacy_files:
                    legacy_size = sum(f.stat().st_size for f in legacy_files)
                    stats['total_checkpoints'] += len(legacy_files)
                    stats['total_size_mb'] += legacy_size / (1024 * 1024)

                    # Add legacy models to stats
                    for model_dir_name in ['cnn', 'dqn', 'rl', 'transformer', 'decision']:
                        model_dir = self.legacy_checkpoints_dir / model_dir_name
                        if model_dir.exists():
                            model_files = list(model_dir.rglob('*.pt'))
                            if model_files and model_dir_name not in stats['models']:
                                stats['total_models'] += 1
                                model_size = sum(f.stat().st_size for f in model_files)
                                latest_file = max(model_files, key=lambda f: f.stat().st_mtime)

                                stats['models'][model_dir_name] = {
                                    'checkpoint_count': len(model_files),
                                    'total_size_mb': model_size / (1024 * 1024),
                                    'best_performance': 0.0,
                                    'best_checkpoint_id': latest_file.name,
                                    'latest_checkpoint': latest_file.name,
                                    'location': 'legacy'
                                }

            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
                'error': str(e)
            }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get the model performance leaderboard"""
        try:
            leaderboard = []

            for model_name, model_info in self.metadata['models'].items():
                if 'metrics' in model_info:
                    metrics = ModelMetrics(**model_info['metrics'])
                    leaderboard.append({
                        'model_name': model_name,
                        'model_type': model_info.get('model_type', 'unknown'),
                        'composite_score': metrics.get_composite_score(),
                        'accuracy': metrics.accuracy,
                        'profit_factor': metrics.profit_factor,
                        'win_rate': metrics.win_rate,
                        'last_updated': model_info.get('last_saved', 'unknown')
                    })

            # Sort by composite score
            leaderboard.sort(key=lambda x: x['composite_score'], reverse=True)
            return leaderboard

        except Exception as e:
            logger.error(f"Error getting leaderboard: {e}")
            return []
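
    # The legacy compatibility wrappers below call ModelManager.save_model()
    # and ModelManager.load_model(), which are not defined elsewhere in this
    # module. The two methods below are a minimal sketch of that interface,
    # assuming it mirrors save_checkpoint()/load_best_checkpoint(); adjust to
    # the real registry layout if it differs.
    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Save a model file and record it in the unified registry (sketch)"""
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            model_dir.mkdir(parents=True, exist_ok=True)
            filepath = model_dir / f"{model_name}_latest.pt"
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'metadata': metadata or {},
                'saved_at': datetime.now().isoformat()
            }, filepath)

            entry = self.metadata['models'].get(model_name, {})
            self.metadata['models'][model_name] = {
                'type': model_type,
                'latest_path': str(filepath),
                'last_saved': datetime.now().isoformat(),
                'save_count': entry.get('save_count', 0) + 1,
                'checkpoints': entry.get('checkpoints', [])
            }
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
            return True
        except Exception as e:
            logger.error(f"Error saving model {model_name}: {e}")
            return False

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """Load the best available checkpoint for a model (sketch)"""
        try:
            result = self.load_best_checkpoint(model_name)
            if result is None:
                return None
            file_path, _ = result
            checkpoint = torch.load(file_path, map_location='cpu')
            if model_class is not None and 'model_state_dict' in checkpoint:
                # Instantiate the caller-supplied class and restore weights
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model
            return checkpoint
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {e}")
            return None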


# ===== LEGACY COMPATIBILITY FUNCTIONS =====

def create_model_manager() -> ModelManager:
    """Create and return a ModelManager instance"""
    return ModelManager()


def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Legacy compatibility function to save a model"""
    manager = create_model_manager()
    return manager.save_model(model, model_name, model_type, metadata)


def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Legacy compatibility function to load a model"""
    manager = create_model_manager()
    return manager.load_model(model_name, model_type, model_class)


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Legacy compatibility function to save a checkpoint"""
    manager = create_model_manager()
    return manager.save_checkpoint(model, model_name, model_type,
                                   performance_metrics, training_metadata, force_save)


def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Legacy compatibility function to load the best checkpoint"""
    manager = create_model_manager()
    return manager.load_best_checkpoint(model_name)


# ===== EXAMPLE USAGE =====
if __name__ == "__main__":
    # Example usage of the unified model manager
    manager = create_model_manager()
    print(f"ModelManager initialized at: {manager.checkpoints_dir}")

    # Get storage stats
    stats = manager.get_storage_stats()
    print(f"Storage stats: {stats}")

    # Get leaderboard
    leaderboard = manager.get_model_leaderboard()
    print(f"Models in leaderboard: {len(leaderboard)}")