unify model names

This commit is contained in:
Dobromir Popov
2025-09-09 01:10:35 +03:00
parent 0e886527c8
commit 317c703ea0
5 changed files with 229 additions and 104 deletions

View File

@@ -532,22 +532,18 @@ class ModelManager:
if not self.legacy_checkpoints_dir.exists():
return None
# Define search patterns for different model types
# Handle both orchestrator naming and direct model naming
model_patterns = {
'dqn_agent': ['dqn_agent', 'dqn', 'agent'],
'enhanced_cnn': ['cnn_model', 'enhanced_cnn', 'cnn', 'optimized_short_term'],
'cob_rl': ['cob_rl', 'rl', 'rl_agent', 'trading_agent'],
'transformer': ['transformer', 'decision'],
'decision': ['decision', 'transformer'],
# Also support direct model names
'dqn': ['dqn_agent', 'dqn', 'agent'],
'cnn': ['cnn_model', 'cnn', 'optimized_short_term'],
'rl': ['cob_rl', 'rl', 'rl_agent']
}
# Use unified model naming throughout the project
# All model references use consistent short names: dqn, cnn, cob_rl, transformer, decision
# This eliminates complex mapping and ensures consistency across the entire codebase
patterns = [model_name]
# Get patterns for this model name or use generic patterns
patterns = model_patterns.get(model_name, [model_name])
# Add minimal backward compatibility patterns
if model_name == 'dqn':
patterns.extend(['dqn_agent', 'agent'])
elif model_name == 'cnn':
patterns.extend(['cnn_model', 'enhanced_cnn'])
elif model_name == 'cob_rl':
patterns.extend(['rl', 'rl_agent', 'trading_agent'])
# Search in legacy saved directory first
legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
@@ -558,7 +554,7 @@ class ModelManager:
return file_path
# Search in model-specific directories
for model_type in model_patterns.keys():
for model_type in ['cnn', 'dqn', 'rl', 'transformer', 'decision']:
model_dir = self.legacy_checkpoints_dir / model_type
if model_dir.exists():
saved_dir = model_dir / "saved"