models load
@@ -197,6 +197,10 @@ class ModelManager:
         self.nn_models_dir = self.base_dir / "NN" / "models"
         self.legacy_models_dir = self.base_dir / "models"
 
+        # Legacy checkpoint directories (where existing checkpoints are stored)
+        self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
+        self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"
+
         # Metadata and checkpoint management
         self.metadata_file = self.checkpoints_dir / "model_metadata.json"
         self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
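For orientation, assuming `base_dir` is the project root, the new legacy attributes resolve like this (a minimal sketch; the concrete root path is hypothetical, not taken from the commit):

```python
from pathlib import Path

base_dir = Path("/project")  # hypothetical project root
nn_models_dir = base_dir / "NN" / "models"
legacy_checkpoints_dir = nn_models_dir / "checkpoints"
legacy_registry_file = legacy_checkpoints_dir / "registry_metadata.json"

print(legacy_registry_file)  # /project/NN/models/checkpoints/registry_metadata.json
```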
@@ -232,14 +236,72 @@ class ModelManager:
             directory.mkdir(parents=True, exist_ok=True)
 
     def _load_metadata(self) -> Dict[str, Any]:
-        """Load model metadata"""
+        """Load model metadata with legacy support"""
+        metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}
+
+        # First try to load from new unified metadata
         if self.metadata_file.exists():
             try:
                 with open(self.metadata_file, 'r') as f:
-                    return json.load(f)
+                    metadata = json.load(f)
+                logger.info(f"Loaded unified metadata from {self.metadata_file}")
             except Exception as e:
-                logger.error(f"Error loading metadata: {e}")
-        return {'models': {}, 'last_updated': datetime.now().isoformat()}
+                logger.error(f"Error loading unified metadata: {e}")
+
+        # Also load legacy metadata for backward compatibility
+        if self.legacy_registry_file.exists():
+            try:
+                with open(self.legacy_registry_file, 'r') as f:
+                    legacy_data = json.load(f)
+
+                # Merge legacy data into unified metadata
+                if 'models' in legacy_data:
+                    for model_name, model_info in legacy_data['models'].items():
+                        if model_name not in metadata['models']:
+                            # Convert legacy path format to absolute path
+                            if 'latest_path' in model_info:
+                                legacy_path = model_info['latest_path']
+
+                                # Handle different legacy path formats
+                                if not legacy_path.startswith('/'):
+                                    # Try multiple path resolution strategies
+                                    possible_paths = [
+                                        self.legacy_checkpoints_dir / legacy_path,         # NN/models/checkpoints/models/cnn/...
+                                        self.legacy_checkpoints_dir.parent / legacy_path,  # NN/models/models/cnn/...
+                                        self.base_dir / legacy_path,                       # /project/models/cnn/...
+                                    ]
+
+                                    resolved_path = None
+                                    for path in possible_paths:
+                                        if path.exists():
+                                            resolved_path = path
+                                            break
+
+                                    if resolved_path:
+                                        legacy_path = str(resolved_path)
+                                    else:
+                                        # If no resolved path found, try to find the file by pattern
+                                        filename = Path(legacy_path).name
+                                        for search_path in [self.legacy_checkpoints_dir]:
+                                            for file_path in search_path.rglob(filename):
+                                                legacy_path = str(file_path)
+                                                break
+
+                                metadata['models'][model_name] = {
+                                    'type': model_info.get('type', 'unknown'),
+                                    'latest_path': legacy_path,
+                                    'last_saved': model_info.get('last_saved', 'legacy'),
+                                    'save_count': model_info.get('save_count', 1),
+                                    'checkpoints': model_info.get('checkpoints', [])
+                                }
+                                logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")
+
+                logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
+
+            except Exception as e:
+                logger.error(f"Error loading legacy metadata: {e}")
+
+        return metadata
 
     def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
         """Load checkpoint metadata"""
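The merge above reads only a handful of keys from the legacy registry. A minimal sketch of what a legacy `registry_metadata.json` might contain, assuming the shape implied by the code (all field values are hypothetical):

```python
# Hypothetical legacy registry contents; only 'models' and the per-model keys
# the merge reads ('type', 'latest_path', 'last_saved', 'save_count',
# 'checkpoints') are assumed here.
legacy_data = {
    "models": {
        "cnn": {
            "type": "cnn",
            # Relative path: triggers the multi-strategy resolution above
            "latest_path": "models/cnn/saved/cnn_latest.pt",
            "last_saved": "2024-01-01T00:00:00",
            "save_count": 3,
            "checkpoints": []
        }
    }
}
```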
@@ -407,34 +469,129 @@ class ModelManager:
         return None
 
     def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
-        """Load the best checkpoint for a model"""
+        """Load the best checkpoint for a model with legacy support"""
         try:
             # First, try the unified registry
             model_info = self.metadata['models'].get(model_name)
             if model_info and Path(model_info['latest_path']).exists():
-                # Load from unified registry
-                load_dict = torch.load(model_info['latest_path'], map_location='cpu')
-                return model_info['latest_path'], None
+                logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
+                # Create metadata from model info for compatibility
+                registry_metadata = CheckpointMetadata(
+                    checkpoint_id=f"{model_name}_registry",
+                    model_name=model_name,
+                    model_type=model_info.get('type', model_name),
+                    file_path=model_info['latest_path'],
+                    created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())),
+                    file_size_mb=0.0,  # Will be calculated if needed
+                    performance_score=0.0,  # Unknown from registry
+                    accuracy=None,
+                    loss=None,  # Orchestrator will handle this
+                    val_accuracy=None,
+                    val_loss=None
+                )
+                return model_info['latest_path'], registry_metadata
 
             # Fallback to checkpoint metadata
             checkpoints = self.checkpoint_metadata.get(model_name, [])
-            if not checkpoints:
-                logger.warning(f"No checkpoints found for {model_name}")
-                return None
+            if checkpoints:
+                # Get best checkpoint
+                best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
 
-            # Get best checkpoint
-            best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
+                if Path(best_checkpoint.file_path).exists():
+                    logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
+                    return best_checkpoint.file_path, best_checkpoint
 
-            if not Path(best_checkpoint.file_path).exists():
-                logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
-                return None
+            # Legacy fallback: Look for checkpoints in legacy directories
+            logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
+            legacy_path = self._find_legacy_checkpoint(model_name)
+            if legacy_path:
+                logger.info(f"Found legacy checkpoint: {legacy_path}")
+                # Create a basic CheckpointMetadata for the legacy checkpoint
+                legacy_metadata = CheckpointMetadata(
+                    checkpoint_id=f"legacy_{model_name}",
+                    model_name=model_name,
+                    model_type=model_name,  # Will be inferred from model type
+                    file_path=str(legacy_path),
+                    created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
+                    file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
+                    performance_score=0.0,  # Unknown for legacy
+                    accuracy=None,
+                    loss=None
+                )
+                return str(legacy_path), legacy_metadata
 
-            return best_checkpoint.file_path, best_checkpoint
+            logger.warning(f"No checkpoints found for {model_name} in any location")
+            return None
 
         except Exception as e:
             logger.error(f"Error loading best checkpoint for {model_name}: {e}")
             return None
 
+    def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
+        """Find checkpoint in legacy directories"""
+        if not self.legacy_checkpoints_dir.exists():
+            return None
+
+        # Define search patterns for different model types
+        # Handle both orchestrator naming and direct model naming
+        model_patterns = {
+            'dqn_agent': ['dqn_agent', 'dqn', 'agent'],
+            'enhanced_cnn': ['cnn_model', 'enhanced_cnn', 'cnn', 'optimized_short_term'],
+            'cob_rl': ['cob_rl', 'rl', 'rl_agent', 'trading_agent'],
+            'transformer': ['transformer', 'decision'],
+            'decision': ['decision', 'transformer'],
+            # Also support direct model names
+            'dqn': ['dqn_agent', 'dqn', 'agent'],
+            'cnn': ['cnn_model', 'cnn', 'optimized_short_term'],
+            'rl': ['cob_rl', 'rl', 'rl_agent']
+        }
+
+        # Get patterns for this model name or use generic patterns
+        patterns = model_patterns.get(model_name, [model_name])
+
+        # Search in legacy saved directory first
+        legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
+        if legacy_saved_dir.exists():
+            for file_path in legacy_saved_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Search in model-specific directories
+        for model_type in model_patterns.keys():
+            model_dir = self.legacy_checkpoints_dir / model_type
+            if model_dir.exists():
+                saved_dir = model_dir / "saved"
+                if saved_dir.exists():
+                    for file_path in saved_dir.rglob("*.pt"):
+                        filename = file_path.name.lower()
+                        if any(pattern in filename for pattern in patterns):
+                            return file_path
+
+        # Search in archive directory
+        archive_dir = self.legacy_checkpoints_dir / "archive"
+        if archive_dir.exists():
+            for file_path in archive_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Search in backtest directory (might contain RL or other models)
+        backtest_dir = self.legacy_checkpoints_dir / "backtest"
+        if backtest_dir.exists():
+            for file_path in backtest_dir.rglob("*.pt"):
+                filename = file_path.name.lower()
+                if any(pattern in filename for pattern in patterns):
+                    return file_path
+
+        # Last resort: search entire legacy directory
+        for file_path in self.legacy_checkpoints_dir.rglob("*.pt"):
+            filename = file_path.name.lower()
+            if any(pattern in filename for pattern in patterns):
+                return file_path
+
+        return None
+
     def get_storage_stats(self) -> Dict[str, Any]:
         """Get storage statistics"""
         try:
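Taken together, `load_best_checkpoint` now tries the unified registry, then unified checkpoint metadata, then the legacy directory scan. A minimal usage sketch, assuming `ModelManager` can be constructed with defaults and that `'dqn_agent'` is one of the orchestrator-style names covered by the patterns table:

```python
manager = ModelManager()  # assumed default construction

result = manager.load_best_checkpoint("dqn_agent")
if result is not None:
    path, meta = result
    # checkpoint_id prefixes distinguish where the checkpoint came from
    source = "legacy" if meta.checkpoint_id.startswith("legacy_") else "unified"
    print(f"Loading {path} ({source}); score={meta.performance_score}")
else:
    print("No checkpoint found in unified or legacy locations")
```

Note that `_find_legacy_checkpoint` matches on filename substrings, so the first `*.pt` file whose name contains any pattern wins; the search order (saved, per-model saved, archive, backtest, then the whole legacy tree) determines which match is preferred.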
@@ -467,7 +624,7 @@ class ModelManager:
                 'models': {}
             }
 
-            # Count files in different directories as "checkpoints"
+            # Count files in new unified directories
             checkpoint_dirs = [
                 self.checkpoints_dir / "cnn",
                 self.checkpoints_dir / "dqn",
@@ -511,6 +668,34 @@ class ModelManager:
                     saved_size = sum(f.stat().st_size for f in saved_files)
                     stats['total_size_mb'] += saved_size / (1024 * 1024)
 
+            # Add legacy checkpoint statistics
+            if self.legacy_checkpoints_dir.exists():
+                legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
+                if legacy_files:
+                    legacy_size = sum(f.stat().st_size for f in legacy_files)
+                    stats['total_checkpoints'] += len(legacy_files)
+                    stats['total_size_mb'] += legacy_size / (1024 * 1024)
+
+                # Add legacy models to stats
+                legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision']
+                for model_dir_name in legacy_model_dirs:
+                    model_dir = self.legacy_checkpoints_dir / model_dir_name
+                    if model_dir.exists():
+                        model_files = list(model_dir.rglob('*.pt'))
+                        if model_files and model_dir_name not in stats['models']:
+                            stats['total_models'] += 1
+                            model_size = sum(f.stat().st_size for f in model_files)
+                            latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
+
+                            stats['models'][model_dir_name] = {
+                                'checkpoint_count': len(model_files),
+                                'total_size_mb': model_size / (1024 * 1024),
+                                'best_performance': 0.0,
+                                'best_checkpoint_id': latest_file.name,
+                                'latest_checkpoint': latest_file.name,
+                                'location': 'legacy'
+                            }
+
             return stats
 
         except Exception as e:
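With the legacy additions, `get_storage_stats()` folds old checkpoints into the same totals. An illustrative result shape, with hypothetical numbers (only keys visible in this diff are shown; `'location': 'legacy'` marks models counted from the old `NN/models/checkpoints` tree):

```python
stats = {
    'total_models': 4,
    'total_checkpoints': 27,
    'total_size_mb': 512.3,
    'models': {
        'cnn': {
            'checkpoint_count': 9,
            'total_size_mb': 210.0,
            'best_performance': 0.0,  # unknown for legacy checkpoints
            'best_checkpoint_id': 'cnn_latest.pt',
            'latest_checkpoint': 'cnn_latest.pt',
            'location': 'legacy'
        }
    }
}
```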