models load
This commit is contained in:
@@ -197,6 +197,10 @@ class ModelManager:
|
||||
self.nn_models_dir = self.base_dir / "NN" / "models"
|
||||
self.legacy_models_dir = self.base_dir / "models"
|
||||
|
||||
# Legacy checkpoint directories (where existing checkpoints are stored)
|
||||
self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
|
||||
self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"
|
||||
|
||||
# Metadata and checkpoint management
|
||||
self.metadata_file = self.checkpoints_dir / "model_metadata.json"
|
||||
self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
|
||||
@@ -232,14 +236,72 @@ class ModelManager:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_metadata(self) -> Dict[str, Any]:
|
||||
"""Load model metadata"""
|
||||
"""Load model metadata with legacy support"""
|
||||
metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}
|
||||
|
||||
# First try to load from new unified metadata
|
||||
if self.metadata_file.exists():
|
||||
try:
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
metadata = json.load(f)
|
||||
logger.info(f"Loaded unified metadata from {self.metadata_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata: {e}")
|
||||
return {'models': {}, 'last_updated': datetime.now().isoformat()}
|
||||
logger.error(f"Error loading unified metadata: {e}")
|
||||
|
||||
# Also load legacy metadata for backward compatibility
|
||||
if self.legacy_registry_file.exists():
|
||||
try:
|
||||
with open(self.legacy_registry_file, 'r') as f:
|
||||
legacy_data = json.load(f)
|
||||
|
||||
# Merge legacy data into unified metadata
|
||||
if 'models' in legacy_data:
|
||||
for model_name, model_info in legacy_data['models'].items():
|
||||
if model_name not in metadata['models']:
|
||||
# Convert legacy path format to absolute path
|
||||
if 'latest_path' in model_info:
|
||||
legacy_path = model_info['latest_path']
|
||||
|
||||
# Handle different legacy path formats
|
||||
if not legacy_path.startswith('/'):
|
||||
# Try multiple path resolution strategies
|
||||
possible_paths = [
|
||||
self.legacy_checkpoints_dir / legacy_path, # NN/models/checkpoints/models/cnn/...
|
||||
self.legacy_checkpoints_dir.parent / legacy_path, # NN/models/models/cnn/...
|
||||
self.base_dir / legacy_path, # /project/models/cnn/...
|
||||
]
|
||||
|
||||
resolved_path = None
|
||||
for path in possible_paths:
|
||||
if path.exists():
|
||||
resolved_path = path
|
||||
break
|
||||
|
||||
if resolved_path:
|
||||
legacy_path = str(resolved_path)
|
||||
else:
|
||||
# If no resolved path found, try to find the file by pattern
|
||||
filename = Path(legacy_path).name
|
||||
for search_path in [self.legacy_checkpoints_dir]:
|
||||
for file_path in search_path.rglob(filename):
|
||||
legacy_path = str(file_path)
|
||||
break
|
||||
|
||||
metadata['models'][model_name] = {
|
||||
'type': model_info.get('type', 'unknown'),
|
||||
'latest_path': legacy_path,
|
||||
'last_saved': model_info.get('last_saved', 'legacy'),
|
||||
'save_count': model_info.get('save_count', 1),
|
||||
'checkpoints': model_info.get('checkpoints', [])
|
||||
}
|
||||
logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")
|
||||
|
||||
logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading legacy metadata: {e}")
|
||||
|
||||
return metadata
|
||||
|
||||
def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Load checkpoint metadata"""
|
||||
@@ -407,34 +469,129 @@ class ModelManager:
|
||||
return None
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
"""Load the best checkpoint for a model"""
|
||||
"""Load the best checkpoint for a model with legacy support"""
|
||||
try:
|
||||
# First, try the unified registry
|
||||
model_info = self.metadata['models'].get(model_name)
|
||||
if model_info and Path(model_info['latest_path']).exists():
|
||||
# Load from unified registry
|
||||
load_dict = torch.load(model_info['latest_path'], map_location='cpu')
|
||||
return model_info['latest_path'], None
|
||||
logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
|
||||
# Create metadata from model info for compatibility
|
||||
registry_metadata = CheckpointMetadata(
|
||||
checkpoint_id=f"{model_name}_registry",
|
||||
model_name=model_name,
|
||||
model_type=model_info.get('type', model_name),
|
||||
file_path=model_info['latest_path'],
|
||||
created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())),
|
||||
file_size_mb=0.0, # Will be calculated if needed
|
||||
performance_score=0.0, # Unknown from registry
|
||||
accuracy=None,
|
||||
loss=None, # Orchestrator will handle this
|
||||
val_accuracy=None,
|
||||
val_loss=None
|
||||
)
|
||||
return model_info['latest_path'], registry_metadata
|
||||
|
||||
# Fallback to checkpoint metadata
|
||||
checkpoints = self.checkpoint_metadata.get(model_name, [])
|
||||
if not checkpoints:
|
||||
logger.warning(f"No checkpoints found for {model_name}")
|
||||
return None
|
||||
if checkpoints:
|
||||
# Get best checkpoint
|
||||
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
|
||||
|
||||
# Get best checkpoint
|
||||
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
|
||||
if Path(best_checkpoint.file_path).exists():
|
||||
logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
|
||||
if not Path(best_checkpoint.file_path).exists():
|
||||
logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
|
||||
return None
|
||||
# Legacy fallback: Look for checkpoints in legacy directories
|
||||
logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
|
||||
legacy_path = self._find_legacy_checkpoint(model_name)
|
||||
if legacy_path:
|
||||
logger.info(f"Found legacy checkpoint: {legacy_path}")
|
||||
# Create a basic CheckpointMetadata for the legacy checkpoint
|
||||
legacy_metadata = CheckpointMetadata(
|
||||
checkpoint_id=f"legacy_{model_name}",
|
||||
model_name=model_name,
|
||||
model_type=model_name, # Will be inferred from model type
|
||||
file_path=str(legacy_path),
|
||||
created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
|
||||
file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
|
||||
performance_score=0.0, # Unknown for legacy
|
||||
accuracy=None,
|
||||
loss=None
|
||||
)
|
||||
return str(legacy_path), legacy_metadata
|
||||
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
logger.warning(f"No checkpoints found for {model_name} in any location")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
|
||||
"""Find checkpoint in legacy directories"""
|
||||
if not self.legacy_checkpoints_dir.exists():
|
||||
return None
|
||||
|
||||
# Define search patterns for different model types
|
||||
# Handle both orchestrator naming and direct model naming
|
||||
model_patterns = {
|
||||
'dqn_agent': ['dqn_agent', 'dqn', 'agent'],
|
||||
'enhanced_cnn': ['cnn_model', 'enhanced_cnn', 'cnn', 'optimized_short_term'],
|
||||
'cob_rl': ['cob_rl', 'rl', 'rl_agent', 'trading_agent'],
|
||||
'transformer': ['transformer', 'decision'],
|
||||
'decision': ['decision', 'transformer'],
|
||||
# Also support direct model names
|
||||
'dqn': ['dqn_agent', 'dqn', 'agent'],
|
||||
'cnn': ['cnn_model', 'cnn', 'optimized_short_term'],
|
||||
'rl': ['cob_rl', 'rl', 'rl_agent']
|
||||
}
|
||||
|
||||
# Get patterns for this model name or use generic patterns
|
||||
patterns = model_patterns.get(model_name, [model_name])
|
||||
|
||||
# Search in legacy saved directory first
|
||||
legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
|
||||
if legacy_saved_dir.exists():
|
||||
for file_path in legacy_saved_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in model-specific directories
|
||||
for model_type in model_patterns.keys():
|
||||
model_dir = self.legacy_checkpoints_dir / model_type
|
||||
if model_dir.exists():
|
||||
saved_dir = model_dir / "saved"
|
||||
if saved_dir.exists():
|
||||
for file_path in saved_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in archive directory
|
||||
archive_dir = self.legacy_checkpoints_dir / "archive"
|
||||
if archive_dir.exists():
|
||||
for file_path in archive_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in backtest directory (might contain RL or other models)
|
||||
backtest_dir = self.legacy_checkpoints_dir / "backtest"
|
||||
if backtest_dir.exists():
|
||||
for file_path in backtest_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Last resort: search entire legacy directory
|
||||
for file_path in self.legacy_checkpoints_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
return None
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage statistics"""
|
||||
try:
|
||||
@@ -467,7 +624,7 @@ class ModelManager:
|
||||
'models': {}
|
||||
}
|
||||
|
||||
# Count files in different directories as "checkpoints"
|
||||
# Count files in new unified directories
|
||||
checkpoint_dirs = [
|
||||
self.checkpoints_dir / "cnn",
|
||||
self.checkpoints_dir / "dqn",
|
||||
@@ -511,6 +668,34 @@ class ModelManager:
|
||||
saved_size = sum(f.stat().st_size for f in saved_files)
|
||||
stats['total_size_mb'] += saved_size / (1024 * 1024)
|
||||
|
||||
# Add legacy checkpoint statistics
|
||||
if self.legacy_checkpoints_dir.exists():
|
||||
legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
|
||||
if legacy_files:
|
||||
legacy_size = sum(f.stat().st_size for f in legacy_files)
|
||||
stats['total_checkpoints'] += len(legacy_files)
|
||||
stats['total_size_mb'] += legacy_size / (1024 * 1024)
|
||||
|
||||
# Add legacy models to stats
|
||||
legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision']
|
||||
for model_dir_name in legacy_model_dirs:
|
||||
model_dir = self.legacy_checkpoints_dir / model_dir_name
|
||||
if model_dir.exists():
|
||||
model_files = list(model_dir.rglob('*.pt'))
|
||||
if model_files and model_dir_name not in stats['models']:
|
||||
stats['total_models'] += 1
|
||||
model_size = sum(f.stat().st_size for f in model_files)
|
||||
latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
|
||||
|
||||
stats['models'][model_dir_name] = {
|
||||
'checkpoint_count': len(model_files),
|
||||
'total_size_mb': model_size / (1024 * 1024),
|
||||
'best_performance': 0.0,
|
||||
'best_checkpoint_id': latest_file.name,
|
||||
'latest_checkpoint': latest_file.name,
|
||||
'location': 'legacy'
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user