"""
Models Module

Provides model registry and interfaces for the trading system.
This module acts as a bridge between the core system and the NN models.
"""

import logging
from typing import Dict, Any, Optional, List

from NN.models.model_interfaces import (
    ModelInterface,
    CNNModelInterface,
    RLAgentInterface,
    ExtremaTrainerInterface,
)

logger = logging.getLogger(__name__)


class ModelRegistry:
    """Registry for managing trading models.

    Keeps two dictionaries keyed by model name:

    * ``models`` maps a name to the registered model object.
    * ``model_performance`` maps a name to a stats dict with the keys
      ``'correct'``, ``'total'``, ``'accuracy'`` and ``'last_used'``.
    """

    def __init__(self):
        """Create an empty registry."""
        self.models: Dict[str, ModelInterface] = {}
        self.model_performance: Dict[str, Dict[str, Any]] = {}

    def register_model(self, name: str, model: ModelInterface):
        """Register a model in the registry.

        Registering under a name that is already present replaces the stored
        model and resets its performance statistics; a warning is logged when
        that happens.

        Args:
            name: Key under which the model is stored.
            model: The model object to store.
        """
        if name in self.models:
            logger.warning(
                "Model %s is already registered; replacing it and resetting "
                "its performance stats",
                name,
            )
        self.models[name] = model
        # NOTE(review): 'last_used' is initialised here, but nothing visible
        # in this module ever updates it - confirm whether a caller maintains it.
        self.model_performance[name] = {
            'correct': 0,
            'total': 0,
            'accuracy': 0.0,
            'last_used': None,
        }
        logger.info("Registered model: %s", name)

    def get_model(self, name: str) -> Optional[ModelInterface]:
        """Get a model by name, or ``None`` if no model has that name."""
        return self.models.get(name)

    def get_all_models(self) -> Dict[str, ModelInterface]:
        """Get all registered models.

        Returns:
            A shallow copy of the name -> model mapping, so adding or removing
            keys in the result does not affect the registry.
        """
        return self.models.copy()

    def update_performance(self, name: str, correct: bool):
        """Update model performance metrics.

        Increments the model's ``'total'`` counter, increments ``'correct'``
        when ``correct`` is truthy, and recomputes ``'accuracy'`` as
        ``correct / total``.

        A name without a performance entry leaves the registry untouched; a
        warning is logged so the dropped result is not lost silently.

        Args:
            name: Name of the model whose stats should be updated.
            correct: Whether the model's latest result counts as correct.
        """
        perf = self.model_performance.get(name)
        if perf is None:
            logger.warning(
                "Cannot update performance for unregistered model: %s", name
            )
            return

        perf['total'] += 1
        if correct:
            perf['correct'] += 1
        # 'total' was just incremented, so it is at least 1 here.
        perf['accuracy'] = perf['correct'] / perf['total']

    def get_best_model(self, model_type: Optional[str] = None) -> Optional[str]:
        """Get the name of the best performing model.

        Models are compared by their stored ``'accuracy'``. A model that has
        never been evaluated still has accuracy ``0.0`` and is therefore
        eligible. On a tie, the earliest-registered model wins because only a
        strictly greater accuracy replaces the current best.

        Args:
            model_type: Optional case-insensitive name prefix. When given
                (and non-empty), only models whose name starts with it are
                considered.

        Returns:
            The best model's name, or ``None`` if the registry is empty or no
            model name matches ``model_type``.
        """
        # Lower-case the prefix once instead of on every iteration.
        prefix = model_type.lower() if model_type else None

        best_model = None
        best_accuracy = -1.0
        # An empty registry simply skips the loop and returns None.
        for name, perf in self.model_performance.items():
            if prefix and not name.lower().startswith(prefix):
                continue
            if perf['accuracy'] > best_accuracy:
                best_accuracy = perf['accuracy']
                best_model = name

        return best_model


# Global model registry instance
_model_registry = ModelRegistry()


def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    return _model_registry


def register_model(name: str, model: ModelInterface):
    """Register a model in the global registry."""
    _model_registry.register_model(name, model)


def get_model(name: str) -> Optional[ModelInterface]:
    """Get a model from the global registry, or ``None`` if it is unknown."""
    return _model_registry.get_model(name)


def get_all_models() -> Dict[str, ModelInterface]:
    """Get a shallow copy of all models in the global registry."""
    return _model_registry.get_all_models()


# Export the interfaces
__all__ = [
    'ModelRegistry',
    'get_model_registry',
    'register_model',
    'get_model',
    'get_all_models',
    'ModelInterface',
    'CNNModelInterface',
    'RLAgentInterface',
    'ExtremaTrainerInterface',
]