models load

This commit is contained in:
Dobromir Popov
2025-09-09 00:51:33 +03:00
parent 9671d0d363
commit 0e886527c8

View File

@@ -197,6 +197,10 @@ class ModelManager:
self.nn_models_dir = self.base_dir / "NN" / "models"
self.legacy_models_dir = self.base_dir / "models"
# Legacy checkpoint directories (where existing checkpoints are stored)
self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"
# Metadata and checkpoint management
self.metadata_file = self.checkpoints_dir / "model_metadata.json"
self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
@@ -232,14 +236,72 @@ class ModelManager:
directory.mkdir(parents=True, exist_ok=True)
def _load_metadata(self) -> Dict[str, Any]:
"""Load model metadata"""
"""Load model metadata with legacy support"""
metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}
# First try to load from new unified metadata
if self.metadata_file.exists():
try:
with open(self.metadata_file, 'r') as f:
return json.load(f)
metadata = json.load(f)
logger.info(f"Loaded unified metadata from {self.metadata_file}")
except Exception as e:
logger.error(f"Error loading metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
logger.error(f"Error loading unified metadata: {e}")
# Also load legacy metadata for backward compatibility
if self.legacy_registry_file.exists():
try:
with open(self.legacy_registry_file, 'r') as f:
legacy_data = json.load(f)
# Merge legacy data into unified metadata
if 'models' in legacy_data:
for model_name, model_info in legacy_data['models'].items():
if model_name not in metadata['models']:
# Convert legacy path format to absolute path
if 'latest_path' in model_info:
legacy_path = model_info['latest_path']
# Handle different legacy path formats
if not legacy_path.startswith('/'):
# Try multiple path resolution strategies
possible_paths = [
self.legacy_checkpoints_dir / legacy_path, # NN/models/checkpoints/models/cnn/...
self.legacy_checkpoints_dir.parent / legacy_path, # NN/models/models/cnn/...
self.base_dir / legacy_path, # /project/models/cnn/...
]
resolved_path = None
for path in possible_paths:
if path.exists():
resolved_path = path
break
if resolved_path:
legacy_path = str(resolved_path)
else:
# If no resolved path found, try to find the file by pattern
filename = Path(legacy_path).name
for search_path in [self.legacy_checkpoints_dir]:
for file_path in search_path.rglob(filename):
legacy_path = str(file_path)
break
metadata['models'][model_name] = {
'type': model_info.get('type', 'unknown'),
'latest_path': legacy_path,
'last_saved': model_info.get('last_saved', 'legacy'),
'save_count': model_info.get('save_count', 1),
'checkpoints': model_info.get('checkpoints', [])
}
logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")
logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
except Exception as e:
logger.error(f"Error loading legacy metadata: {e}")
return metadata
def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
"""Load checkpoint metadata"""
@@ -407,34 +469,129 @@ class ModelManager:
return None
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Load the best checkpoint for a model with legacy support.

    Resolution order:
      1. The unified registry entry in ``self.metadata['models']``.
      2. The highest-scoring entry in ``self.checkpoint_metadata``.
      3. A pattern-based scan of the legacy checkpoint directories
         (``self._find_legacy_checkpoint``).

    Args:
        model_name: Registered model name (e.g. 'dqn_agent', 'enhanced_cnn').

    Returns:
        Tuple of (checkpoint file path, CheckpointMetadata), or None if no
        checkpoint file could be located anywhere or an error occurred.
    """
    try:
        # 1) Unified registry: trust the recorded latest path if it exists.
        model_info = self.metadata['models'].get(model_name)
        if model_info and Path(model_info['latest_path']).exists():
            logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
            # Synthesize metadata for compatibility; performance fields are
            # unknown at the registry level.
            registry_metadata = CheckpointMetadata(
                checkpoint_id=f"{model_name}_registry",
                model_name=model_name,
                model_type=model_info.get('type', model_name),
                file_path=model_info['latest_path'],
                created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())),
                file_size_mb=0.0,  # Will be calculated if needed
                performance_score=0.0,  # Unknown from registry
                accuracy=None,
                loss=None,  # Orchestrator will handle this
                val_accuracy=None,
                val_loss=None
            )
            return model_info['latest_path'], registry_metadata

        # 2) Unified checkpoint metadata: pick the best-scoring entry whose
        # file is still on disk.
        checkpoints = self.checkpoint_metadata.get(model_name, [])
        if checkpoints:
            best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
            if Path(best_checkpoint.file_path).exists():
                logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
                return best_checkpoint.file_path, best_checkpoint

        # 3) Legacy fallback: look for checkpoints in legacy directories.
        logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
        legacy_path = self._find_legacy_checkpoint(model_name)
        if legacy_path:
            logger.info(f"Found legacy checkpoint: {legacy_path}")
            # Create a basic CheckpointMetadata for the legacy checkpoint.
            legacy_metadata = CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,  # Best guess; true type unknown for legacy files
                file_path=str(legacy_path),
                created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
                file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
                performance_score=0.0,  # Unknown for legacy
                accuracy=None,
                loss=None
            )
            return str(legacy_path), legacy_metadata

        logger.warning(f"No checkpoints found for {model_name} in any location")
        return None
    except Exception as e:
        logger.error(f"Error loading best checkpoint for {model_name}: {e}")
        return None
def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
"""Find checkpoint in legacy directories"""
if not self.legacy_checkpoints_dir.exists():
return None
# Define search patterns for different model types
# Handle both orchestrator naming and direct model naming
model_patterns = {
'dqn_agent': ['dqn_agent', 'dqn', 'agent'],
'enhanced_cnn': ['cnn_model', 'enhanced_cnn', 'cnn', 'optimized_short_term'],
'cob_rl': ['cob_rl', 'rl', 'rl_agent', 'trading_agent'],
'transformer': ['transformer', 'decision'],
'decision': ['decision', 'transformer'],
# Also support direct model names
'dqn': ['dqn_agent', 'dqn', 'agent'],
'cnn': ['cnn_model', 'cnn', 'optimized_short_term'],
'rl': ['cob_rl', 'rl', 'rl_agent']
}
# Get patterns for this model name or use generic patterns
patterns = model_patterns.get(model_name, [model_name])
# Search in legacy saved directory first
legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
if legacy_saved_dir.exists():
for file_path in legacy_saved_dir.rglob("*.pt"):
filename = file_path.name.lower()
if any(pattern in filename for pattern in patterns):
return file_path
# Search in model-specific directories
for model_type in model_patterns.keys():
model_dir = self.legacy_checkpoints_dir / model_type
if model_dir.exists():
saved_dir = model_dir / "saved"
if saved_dir.exists():
for file_path in saved_dir.rglob("*.pt"):
filename = file_path.name.lower()
if any(pattern in filename for pattern in patterns):
return file_path
# Search in archive directory
archive_dir = self.legacy_checkpoints_dir / "archive"
if archive_dir.exists():
for file_path in archive_dir.rglob("*.pt"):
filename = file_path.name.lower()
if any(pattern in filename for pattern in patterns):
return file_path
# Search in backtest directory (might contain RL or other models)
backtest_dir = self.legacy_checkpoints_dir / "backtest"
if backtest_dir.exists():
for file_path in backtest_dir.rglob("*.pt"):
filename = file_path.name.lower()
if any(pattern in filename for pattern in patterns):
return file_path
# Last resort: search entire legacy directory
for file_path in self.legacy_checkpoints_dir.rglob("*.pt"):
filename = file_path.name.lower()
if any(pattern in filename for pattern in patterns):
return file_path
return None
def get_storage_stats(self) -> Dict[str, Any]:
"""Get storage statistics"""
try:
@@ -467,7 +624,7 @@ class ModelManager:
'models': {}
}
# Count files in different directories as "checkpoints"
# Count files in new unified directories
checkpoint_dirs = [
self.checkpoints_dir / "cnn",
self.checkpoints_dir / "dqn",
@@ -511,6 +668,34 @@ class ModelManager:
saved_size = sum(f.stat().st_size for f in saved_files)
stats['total_size_mb'] += saved_size / (1024 * 1024)
# Add legacy checkpoint statistics
if self.legacy_checkpoints_dir.exists():
legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
if legacy_files:
legacy_size = sum(f.stat().st_size for f in legacy_files)
stats['total_checkpoints'] += len(legacy_files)
stats['total_size_mb'] += legacy_size / (1024 * 1024)
# Add legacy models to stats
legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision']
for model_dir_name in legacy_model_dirs:
model_dir = self.legacy_checkpoints_dir / model_dir_name
if model_dir.exists():
model_files = list(model_dir.rglob('*.pt'))
if model_files and model_dir_name not in stats['models']:
stats['total_models'] += 1
model_size = sum(f.stat().st_size for f in model_files)
latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
stats['models'][model_dir_name] = {
'checkpoint_count': len(model_files),
'total_size_mb': model_size / (1024 * 1024),
'best_performance': 0.0,
'best_checkpoint_id': latest_file.name,
'latest_checkpoint': latest_file.name,
'location': 'legacy'
}
return stats
except Exception as e: