model selector

Dobromir Popov
2025-09-08 14:53:46 +03:00
parent e61536e43d
commit 32d54f0604
4 changed files with 733 additions and 10 deletions


@@ -1820,14 +1820,19 @@ class TradingOrchestrator:
     def start_enhanced_training(self):
         """Start the enhanced real-time training system"""
         try:
-            if not self.training_enabled or not getattr(self, 'training_manager', None):
+            if not self.training_enabled or not self.enhanced_training_system:
                 logger.warning("Enhanced training system not available")
                 return False
-            self.training_manager.start()
-            logger.info("Enhanced real-time training started")
-            return True
+
+            # Check if the enhanced training system has a start_training method
+            if hasattr(self.enhanced_training_system, 'start_training'):
+                self.enhanced_training_system.start_training()
+                logger.info("Enhanced real-time training started")
+                return True
+            else:
+                logger.warning("Enhanced training system does not have start_training method")
+                return False
         except Exception as e:
             logger.error(f"Error starting enhanced training: {e}")
             return False
@@ -1835,12 +1840,12 @@ class TradingOrchestrator:
     def stop_enhanced_training(self):
         """Stop the enhanced real-time training system"""
         try:
-            if getattr(self, 'training_manager', None):
-                self.training_manager.stop()
+            if self.enhanced_training_system and hasattr(self.enhanced_training_system, 'stop_training'):
+                self.enhanced_training_system.stop_training()
                 logger.info("Enhanced real-time training stopped")
                 return True
             return False
         except Exception as e:
             logger.error(f"Error stopping enhanced training: {e}")
             return False


@@ -87,10 +87,20 @@ def main():
     os.environ['ENABLE_NN_MODELS'] = '1'

     try:
+        # Model Selection at Startup
+        logger.info("Performing intelligent model selection...")
+        try:
+            from utils.model_selector import select_and_load_best_models
+            selected_models, loaded_models = select_and_load_best_models()
+            logger.info(f"Selected {len(selected_models)} model types, loaded {len(loaded_models)} models")
+        except Exception as e:
+            logger.warning(f"Model selection failed, using defaults: {e}")
+            selected_models, loaded_models = {}, {}
+
         # Create data provider
         logger.info("Initializing data provider...")
         data_provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])

         # Create orchestrator (with safe CNN handling)
         logger.info("Initializing trading orchestrator...")
         orchestrator = create_safe_orchestrator()
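For context, `loaded_models` as returned by `select_and_load_best_models()` is a dict keyed by model type, and each entry records whether the weights came from a checkpoint or from the latest save (see utils/model_selector.py below). A minimal sketch of how main() could inspect it, purely illustrative and not part of this commit:

    for model_type, info in loaded_models.items():
        if info.get('source') == 'checkpoint':
            logger.info(f"{model_type}: checkpoint {info['path']} (score={info['performance_score']})")
        else:
            logger.info(f"{model_type}: latest weights at {info.get('path')}")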

test_model_audit.py (new file, +344 lines)

@@ -0,0 +1,344 @@
#!/usr/bin/env python3
"""
Model Loading/Saving Audit Test

This script tests the model registry and saving/loading mechanisms
to identify any issues and provide recommendations.
"""

import os
import sys
import logging
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class SimpleTestModel(nn.Module):
    """Simple neural network for testing"""

    def __init__(self, input_size=10, hidden_size=32, output_size=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.net(x)
def test_model_registry():
    """Test the model registry functionality"""
    logger.info("=== MODEL REGISTRY AUDIT ===")

    registry = get_model_registry()
    logger.info(f"Registry base directory: {registry.base_dir}")
    logger.info(f"Registry metadata file: {registry.metadata_file}")

    # Check existing models
    existing_models = registry.list_models()
    logger.info(f"Existing models: {existing_models}")

    # Test model creation and saving
    logger.info("Creating test model...")
    test_model = SimpleTestModel()

    # Generate some fake training data
    test_input = torch.randn(32, 10)
    test_output = test_model(test_input)
    logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}")

    # Test saving with different methods
    logger.info("Testing model saving...")

    # Test 1: Save with unified registry
    success = save_model(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        metadata={
            "test_type": "registry_audit",
            "created_at": datetime.now().isoformat(),
            "input_shape": list(test_input.shape),
            "output_shape": list(test_output.shape)
        }
    )

    if success:
        logger.info("✅ Model saved successfully with unified registry")
    else:
        logger.error("❌ Failed to save model with unified registry")

    # Test 2: Load model back
    logger.info("Testing model loading...")
    loaded_model = load_model("audit_test_model", "cnn")

    if loaded_model is not None:
        logger.info("✅ Model loaded successfully")

        # Test if loaded model has proper structure
        if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict):
            state_dict = loaded_model.state_dict()
            logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}")

            # Check if we can create a new instance and load the state
            fresh_model = SimpleTestModel()
            try:
                fresh_model.load_state_dict(state_dict)
                test_output_loaded = fresh_model(test_input)
                logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}")

                # Compare outputs (should be identical)
                if torch.allclose(test_output, test_output_loaded, atol=1e-6):
                    logger.info("✅ Loaded model produces identical outputs")
                else:
                    logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)")
            except Exception as e:
                logger.warning(f"Could not test loaded model: {e}")
        else:
            logger.warning("Loaded model does not have proper structure")
    else:
        logger.error("❌ Failed to load model")

    # Test 3: Save checkpoint
    logger.info("Testing checkpoint saving...")
    checkpoint_success = save_checkpoint(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        performance_score=0.85,
        metadata={
            "checkpoint_test": True,
            "performance_metric": "accuracy",
            "epoch": 1
        }
    )

    if checkpoint_success:
        logger.info("✅ Checkpoint saved successfully")
    else:
        logger.error("❌ Failed to save checkpoint")

    # Check registry metadata after operations
    logger.info("Checking registry metadata after operations...")
    updated_models = registry.list_models()
    logger.info(f"Updated models: {updated_models}")

    # Check file system
    logger.info("Checking file system...")
    models_dir = Path("models")
    if models_dir.exists():
        logger.info("Models directory contents:")
        for item in models_dir.rglob("*"):
            if item.is_file():
                logger.info(f"  {item.relative_to(models_dir)} ({item.stat().st_size} bytes)")

    return {
        "registry_save_success": success,
        "registry_load_success": loaded_model is not None,
        "checkpoint_success": checkpoint_success,
        "existing_models": existing_models,
        "updated_models": updated_models
    }
def audit_model_metadata():
    """Audit the model metadata structure"""
    logger.info("=== MODEL METADATA AUDIT ===")

    registry = get_model_registry()

    # Check metadata structure
    metadata = registry.metadata
    logger.info(f"Metadata keys: {list(metadata.keys())}")

    if 'models' in metadata:
        models = metadata['models']
        logger.info(f"Number of registered models: {len(models)}")

        for model_name, model_data in models.items():
            logger.info(f"Model '{model_name}':")
            logger.info(f"  - Type: {model_data.get('type', 'unknown')}")
            logger.info(f"  - Last saved: {model_data.get('last_saved', 'never')}")
            logger.info(f"  - Save count: {model_data.get('save_count', 0)}")
            logger.info(f"  - Latest path: {model_data.get('latest_path', 'none')}")
            logger.info(f"  - Checkpoints: {len(model_data.get('checkpoints', []))}")

    if 'last_updated' in metadata:
        logger.info(f"Last metadata update: {metadata['last_updated']}")

    return metadata
def analyze_model_files():
    """Analyze the model files on disk"""
    logger.info("=== MODEL FILES ANALYSIS ===")

    models_dir = Path("models")
    if not models_dir.exists():
        logger.error("Models directory does not exist")
        return {}

    analysis = {
        'total_files': 0,
        'total_size': 0,
        'by_type': {},
        'by_model': {},
        'orphaned_files': [],
        'missing_files': []
    }

    # Analyze all .pt files
    for pt_file in models_dir.rglob("*.pt"):
        analysis['total_files'] += 1
        analysis['total_size'] += pt_file.stat().st_size

        # Categorize by type
        parts = pt_file.parts
        model_type = "unknown"
        if "cnn" in parts:
            model_type = "cnn"
        elif "dqn" in parts:
            model_type = "dqn"
        elif "transformer" in parts:
            model_type = "transformer"
        elif "hybrid" in parts:
            model_type = "hybrid"

        if model_type not in analysis['by_type']:
            analysis['by_type'][model_type] = []
        analysis['by_type'][model_type].append(str(pt_file))

        # Try to extract model name
        filename = pt_file.name
        if "_latest" in filename:
            model_name = filename.replace("_latest.pt", "")
        elif "_" in filename:
            # Extract timestamp-based names
            name_parts = filename.split("_")
            if len(name_parts) >= 2:
                model_name = "_".join(name_parts[:-1])  # Everything except timestamp
            else:
                model_name = filename.replace(".pt", "")
        else:
            model_name = filename.replace(".pt", "")

        if model_name not in analysis['by_model']:
            analysis['by_model'][model_name] = []
        analysis['by_model'][model_name].append(str(pt_file))

    logger.info(f"Total model files: {analysis['total_files']}")
    logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB")
    logger.info("Files by type:")
    for model_type, files in analysis['by_type'].items():
        logger.info(f"  {model_type}: {len(files)} files")
    logger.info("Files by model:")
    for model_name, files in analysis['by_model'].items():
        logger.info(f"  {model_name}: {len(files)} files")

    return analysis
def recommend_best_model_selection():
    """Provide recommendations for best model selection at startup"""
    logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===")

    registry = get_model_registry()
    models = registry.list_models()

    recommendations = {
        'startup_strategy': 'hybrid',
        'fallback_models': [],
        'performance_criteria': [],
        'metadata_requirements': []
    }

    if models:
        logger.info("Available models for selection:")

        # Analyze each model type
        for model_name, model_info in models.items():
            model_type = model_info.get('type', 'unknown')
            logger.info(f"  {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}")

            # Check if checkpoints exist
            if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0:
                logger.info(f"    - Has {model_info['checkpoint_count']} checkpoints")
                recommendations['fallback_models'].append(model_name)

        # Recommendations
        logger.info("RECOMMENDATIONS:")
        logger.info("1. Startup Strategy:")
        logger.info("   - Try to load latest model for each type")
        logger.info("   - Fall back to checkpoints if latest model fails")
        logger.info("   - Use fallback to basic/default model if all else fails")
        logger.info("2. Performance-based Selection:")
        logger.info("   - For models with checkpoints, select highest performance_score")
        logger.info("   - Track model age and prefer recently trained models")
        logger.info("   - Implement model validation on startup")
        logger.info("3. Metadata Requirements:")
        logger.info("   - Store performance metrics in metadata")
        logger.info("   - Track training data quality and size")
        logger.info("   - Include model validation results")
    else:
        logger.info("No models registered - system will need initial training")
        logger.info("RECOMMENDATION: Implement default model initialization")

    return recommendations
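
# The "latest -> checkpoint -> default" cascade recommended above, sketched here as
# comments for illustration only. load_model and load_best_checkpoint exist in
# utils.model_registry; restore_from_checkpoint and build_default_model are
# hypothetical helpers, not part of this repository:
#
#   model = load_model(name, model_type)                 # 1. latest saved weights
#   if model is None:
#       ckpt = load_best_checkpoint(name, model_type)    # 2. best-scoring checkpoint
#       model = restore_from_checkpoint(ckpt) if ckpt else None
#   if model is None:
#       model = build_default_model(model_type)          # 3. freshly initialized default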
def main():
    """Main audit function"""
    logger.info("Starting Model Loading/Saving Audit")
    logger.info("=" * 60)

    try:
        # Test model registry
        registry_results = test_model_registry()
        logger.info("-" * 40)

        # Audit metadata
        metadata = audit_model_metadata()
        logger.info("-" * 40)

        # Analyze files
        file_analysis = analyze_model_files()
        logger.info("-" * 40)

        # Recommendations
        recommendations = recommend_best_model_selection()
        logger.info("-" * 40)

        # Summary
        logger.info("=== AUDIT SUMMARY ===")
        logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}")
        logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}")
        logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}")
        logger.info(f"Total model files: {file_analysis.get('total_files', 0)}")
        logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}")

        logger.info("Audit completed successfully!")

    except Exception as e:
        logger.error(f"Audit failed with error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
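
A quick way to exercise only the registry round trip, assuming the script is run from the project root so that utils.model_registry resolves (the result keys match the dict returned by test_model_registry above):

    from test_model_audit import test_model_registry
    results = test_model_registry()
    print(results["registry_save_success"], results["registry_load_success"])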

utils/model_selector.py (new file, +364 lines)

@@ -0,0 +1,364 @@
#!/usr/bin/env python3
"""
Best Model Selection for Startup

This module provides intelligent model selection logic for choosing the best
available models at system startup based on various criteria.
"""

import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import torch

from utils.model_registry import get_model_registry, load_model, load_best_checkpoint

logger = logging.getLogger(__name__)


class ModelSelector:
    """
    Intelligent model selector for startup and runtime model selection.
    """

    def __init__(self):
        """Initialize the model selector"""
        self.registry = get_model_registry()
        self.selection_criteria = {
            'max_age_days': 30,            # Don't use models older than 30 days
            'min_performance_score': 0.5,  # Minimum acceptable performance
            'prefer_recent': True,         # Prefer recently trained models
            'fallback_to_any': True        # Use any model if no good ones found
        }
        logger.info("Model Selector initialized")
    def select_best_models_for_startup(self) -> Dict[str, Dict[str, Any]]:
        """
        Select the best available models for each type at startup.

        Returns:
            Dictionary mapping model types to selected model info
        """
        logger.info("Selecting best models for startup...")

        available_models = self.registry.list_models()
        selected_models = {}

        # Group models by type
        models_by_type = {}
        for model_name, model_info in available_models.items():
            model_type = model_info.get('type', 'unknown')
            if model_type not in models_by_type:
                models_by_type[model_type] = []
            models_by_type[model_type].append((model_name, model_info))

        # Select best model for each type
        for model_type, models in models_by_type.items():
            if not models:
                continue

            logger.info(f"Selecting best {model_type} model from {len(models)} candidates")
            best_model = self._select_best_model_for_type(models, model_type)

            if best_model:
                selected_models[model_type] = best_model
                logger.info(f"Selected {best_model['name']} for {model_type}")
            else:
                logger.warning(f"No suitable {model_type} model found")

        return selected_models
    def _select_best_model_for_type(self, models: List[Tuple[str, Dict]], model_type: str) -> Optional[Dict[str, Any]]:
        """
        Select the best model for a specific type.

        Args:
            models: List of (name, info) tuples
            model_type: Type of model to select

        Returns:
            Selected model information or None
        """
        if not models:
            return None

        candidates = []

        for model_name, model_info in models:
            # Check if model meets basic criteria
            if not self._meets_basic_criteria(model_info):
                continue

            # Calculate selection score
            score = self._calculate_selection_score(model_name, model_info, model_type)

            candidates.append({
                'name': model_name,
                'info': model_info,
                'score': score,
                'has_checkpoints': model_info.get('checkpoint_count', 0) > 0
            })

        if not candidates:
            if self.selection_criteria['fallback_to_any']:
                # Fallback to most recent model
                logger.info(f"No good {model_type} candidates, using fallback")
                return self._select_fallback_model(models)
            return None

        # Sort by score (highest first)
        candidates.sort(key=lambda x: x['score'], reverse=True)
        best_candidate = candidates[0]

        # Try to load the model to verify it's working
        if self._verify_model_loadable(best_candidate['name'], model_type):
            return {
                'name': best_candidate['name'],
                'type': model_type,
                'info': best_candidate['info'],
                'score': best_candidate['score'],
                'selection_reason': self._get_selection_reason(best_candidate),
                'verified': True
            }
        else:
            logger.warning(f"Selected model {best_candidate['name']} failed verification")
            # Try next candidate
            if len(candidates) > 1:
                next_candidate = candidates[1]
                if self._verify_model_loadable(next_candidate['name'], model_type):
                    return {
                        'name': next_candidate['name'],
                        'type': model_type,
                        'info': next_candidate['info'],
                        'score': next_candidate['score'],
                        'selection_reason': 'fallback_after_verification_failure',
                        'verified': True
                    }
            return None
    def _meets_basic_criteria(self, model_info: Dict[str, Any]) -> bool:
        """Check if model meets basic selection criteria"""
        # Check age
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                # Parse timestamp (format: YYYYMMDD_HHMMSS)
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                age_days = (datetime.now() - model_date).days
                if age_days > self.selection_criteria['max_age_days']:
                    return False
            except ValueError:
                logger.warning(f"Could not parse timestamp: {last_saved}")

        return True
    def _calculate_selection_score(self, model_name: str, model_info: Dict[str, Any], model_type: str) -> float:
        """Calculate selection score for a model"""
        score = 0.0

        # Base score from recency (newer is better)
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                days_old = (datetime.now() - model_date).days
                recency_score = max(0, 30 - days_old) / 30.0  # 0-1 score for last 30 days
                score += recency_score * 0.4
            except ValueError:
                pass

        # Score from checkpoints (having checkpoints is good)
        checkpoint_count = model_info.get('checkpoint_count', 0)
        if checkpoint_count > 0:
            checkpoint_score = min(checkpoint_count / 10.0, 1.0)  # Max score for 10+ checkpoints
            score += checkpoint_score * 0.3

        # Score from save count (more saves might indicate stability)
        save_count = model_info.get('save_count', 0)
        if save_count > 1:
            stability_score = min(save_count / 5.0, 1.0)  # Max score for 5+ saves
            score += stability_score * 0.3

        return score
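
    # Worked example of the weighting above (illustrative numbers, not taken from the
    # registry): a model saved 6 days ago with 3 checkpoints and 4 saves scores
    #   recency:     (30 - 6) / 30 * 0.4 = 0.32
    #   checkpoints: min(3 / 10, 1.0) * 0.3 = 0.09
    #   stability:   min(4 / 5, 1.0) * 0.3 = 0.24
    #   total:       0.65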
    def _select_fallback_model(self, models: List[Tuple[str, Dict]]) -> Optional[Dict[str, Any]]:
        """Select a fallback model when no good candidates found"""
        if not models:
            return None

        # Sort by recency
        sorted_models = sorted(models, key=lambda x: x[1].get('last_saved', ''), reverse=True)
        model_name, model_info = sorted_models[0]

        return {
            'name': model_name,
            'type': model_info.get('type', 'unknown'),
            'info': model_info,
            'score': 0.0,
            'selection_reason': 'fallback_most_recent',
            'verified': False
        }

    def _verify_model_loadable(self, model_name: str, model_type: str) -> bool:
        """Verify that a model can be loaded successfully"""
        try:
            model = load_model(model_name, model_type)
            return model is not None
        except Exception as e:
            logger.warning(f"Model verification failed for {model_name}: {e}")
            return False

    def _get_selection_reason(self, candidate: Dict[str, Any]) -> str:
        """Get human-readable selection reason"""
        reasons = []

        if candidate.get('has_checkpoints'):
            reasons.append("has_checkpoints")

        score = candidate.get('score', 0)
        if score > 0.8:
            reasons.append("high_score")
        elif score > 0.6:
            reasons.append("good_score")
        else:
            reasons.append("acceptable_score")

        return ", ".join(reasons) if reasons else "default_selection"
    def load_selected_models(self, selected_models: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Load the selected models into memory.

        Args:
            selected_models: Dictionary from select_best_models_for_startup

        Returns:
            Dictionary of loaded models
        """
        loaded_models = {}

        for model_type, selection_info in selected_models.items():
            model_name = selection_info['name']
            logger.info(f"Loading {model_type} model: {model_name}")

            try:
                # Try to load best checkpoint first if available
                if selection_info['info'].get('checkpoint_count', 0) > 0:
                    checkpoint_result = load_best_checkpoint(model_name, model_type)
                    if checkpoint_result:
                        checkpoint_path, checkpoint_data = checkpoint_result
                        loaded_models[model_type] = {
                            'model': None,  # Would need proper model class instantiation
                            'checkpoint_data': checkpoint_data,
                            'source': 'checkpoint',
                            'path': checkpoint_path,
                            'performance_score': checkpoint_data.get('performance_score', 0)
                        }
                        logger.info(f"Loaded {model_type} from checkpoint: {checkpoint_path}")
                        continue

                # Fall back to regular model loading
                model = load_model(model_name, model_type)
                if model:
                    loaded_models[model_type] = {
                        'model': model,
                        'source': 'latest',
                        'path': selection_info['info'].get('latest_path'),
                        'performance_score': None
                    }
                    logger.info(f"Loaded {model_type} from latest: {model_name}")
                else:
                    logger.error(f"Failed to load {model_type} model: {model_name}")

            except Exception as e:
                logger.error(f"Error loading {model_type} model {model_name}: {e}")

        return loaded_models
    def get_startup_report(self, selected_models: Dict[str, Dict[str, Any]],
                           loaded_models: Dict[str, Any]) -> str:
        """Generate a startup report"""
        report_lines = [
            "=" * 60,
            "MODEL STARTUP SELECTION REPORT",
            "=" * 60,
            f"Selection Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ""
        ]

        if selected_models:
            report_lines.append("SELECTED MODELS:")
            for model_type, selection_info in selected_models.items():
                report_lines.append(f"  {model_type.upper()}: {selection_info['name']}")
                report_lines.append(f"    - Score: {selection_info.get('score', 0):.3f}")
                report_lines.append(f"    - Reason: {selection_info.get('selection_reason', 'unknown')}")
                report_lines.append(f"    - Verified: {selection_info.get('verified', False)}")
                report_lines.append(f"    - Last Saved: {selection_info['info'].get('last_saved', 'unknown')}")
            report_lines.append("")
        else:
            report_lines.append("NO MODELS SELECTED")
            report_lines.append("")

        if loaded_models:
            report_lines.append("LOADED MODELS:")
            for model_type, model_info in loaded_models.items():
                source = model_info.get('source', 'unknown')
                report_lines.append(f"  {model_type.upper()}: Loaded from {source}")
                if 'performance_score' in model_info and model_info['performance_score'] is not None:
                    report_lines.append(f"    - Performance Score: {model_info['performance_score']:.3f}")
            report_lines.append("")
        else:
            report_lines.append("NO MODELS LOADED")
            report_lines.append("")

        # Add summary statistics
        total_models = len(self.registry.list_models())
        selected_count = len(selected_models)
        loaded_count = len(loaded_models)

        report_lines.extend([
            "SUMMARY STATISTICS:",
            f"  Total Available Models: {total_models}",
            f"  Models Selected: {selected_count}",
            f"  Models Loaded: {loaded_count}",
            "=" * 60
        ])

        return "\n".join(report_lines)
# Global instance
_model_selector = None


def get_model_selector() -> ModelSelector:
    """Get the global model selector instance"""
    global _model_selector
    if _model_selector is None:
        _model_selector = ModelSelector()
    return _model_selector


def select_and_load_best_models() -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Convenience function to select and load best models for startup.

    Returns:
        Tuple of (selected_models_info, loaded_models)
    """
    selector = get_model_selector()

    # Select best models
    selected_models = selector.select_best_models_for_startup()

    # Load selected models
    loaded_models = selector.load_selected_models(selected_models)

    # Generate and log report
    report = selector.get_startup_report(selected_models, loaded_models)
    logger.info("Model Startup Report:\n" + report)

    return selected_models, loaded_models
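
A caller that wants stricter selection than the defaults can adjust the criteria on the shared selector before running the startup flow. A minimal sketch, assuming the registry already holds saved models; the criteria keys are the ones defined in ModelSelector.__init__:

    from utils.model_selector import get_model_selector

    selector = get_model_selector()
    selector.selection_criteria['max_age_days'] = 7          # only consider models from the last week
    selector.selection_criteria['fallback_to_any'] = False   # skip a type rather than load a stale model

    selected = selector.select_best_models_for_startup()
    loaded = selector.load_selected_models(selected)
    print(selector.get_startup_report(selected, loaded))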