Files
gogo2/models.py
Dobromir Popov 20112ed693 linux fixes
2025-08-29 18:26:35 +03:00

99 lines
3.1 KiB
Python

"""
Models Module
Provides model registry and interfaces for the trading system.
This module acts as a bridge between the core system and the NN models.
"""
import logging
from typing import Dict, Any, Optional, List
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelRegistry:
    """Registry that stores trading models by name and tracks their accuracy.

    Attributes:
        models: Maps a registration name to the registered model object.
        model_performance: Maps a registration name to a statistics dict with
            the keys ``'correct'``, ``'total'``, ``'accuracy'`` and
            ``'last_used'``.
    """

    def __init__(self):
        """Create an empty registry with no models and no statistics."""
        self.models: Dict[str, ModelInterface] = {}
        self.model_performance: Dict[str, Dict[str, Any]] = {}

    def register_model(self, name: str, model: ModelInterface) -> None:
        """Register ``model`` under ``name`` and start fresh statistics for it.

        Registering a name that already exists replaces the stored model and
        resets its statistics to zero.

        Args:
            name: Key under which the model is stored and later looked up.
            model: The model object to store.
        """
        self.models[name] = model
        self.model_performance[name] = {
            'correct': 0,
            'total': 0,
            'accuracy': 0.0,
            # Initialised to None here; no method of this class writes it
            # afterwards. NOTE(review): confirm whether callers maintain it.
            'last_used': None,
        }
        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is actually enabled for this logger.
        logger.info("Registered model: %s", name)

    def get_model(self, name: str) -> Optional[ModelInterface]:
        """Return the model registered under ``name``, or ``None`` if absent."""
        return self.models.get(name)

    def get_all_models(self) -> Dict[str, ModelInterface]:
        """Return a shallow copy of the name -> model mapping.

        Mutating the returned dict does not affect the registry.
        """
        return self.models.copy()

    def update_performance(self, name: str, correct: bool) -> None:
        """Record one prediction outcome for ``name`` and refresh its accuracy.

        Names without statistics (i.e. never registered) are ignored.

        Args:
            name: Registration name of the model that made the prediction.
            correct: Whether that prediction turned out to be correct.
        """
        stats = self.model_performance.get(name)
        if stats is None:
            return

        stats['total'] += 1
        if correct:
            stats['correct'] += 1
        # 'total' was incremented above, so the divisor is at least 1 here.
        stats['accuracy'] = stats['correct'] / stats['total']

    def get_best_model(self, model_type: Optional[str] = None) -> Optional[str]:
        """Return the name of the tracked model with the highest accuracy.

        Args:
            model_type: Optional name prefix, compared case-insensitively.
                When given and non-empty, only models whose name starts with
                it are considered; ``None`` or ``""`` disables the filter.

        Returns:
            The name of the best model, or ``None`` when no model qualifies.
            On equal accuracy the model encountered first in the statistics
            dict wins, because only a strictly greater accuracy replaces the
            current best.
        """
        # Normalise the filter once instead of on every loop iteration.
        prefix = model_type.lower() if model_type else None

        best_model = None
        best_accuracy = -1.0  # below any accuracy produced by this class
        for name, perf in self.model_performance.items():
            if prefix is not None and not name.lower().startswith(prefix):
                continue
            if perf['accuracy'] > best_accuracy:
                best_accuracy = perf['accuracy']
                best_model = name
        return best_model
# Process-wide registry instance, created once when this module is imported.
# The module-level helper functions below all operate on this single object.
_model_registry = ModelRegistry()
def get_model_registry() -> ModelRegistry:
    """Return the module's shared ModelRegistry instance.

    Every call hands back the same object, so changes made through it are
    visible to all other users of the global registry.
    """
    shared_registry = _model_registry
    return shared_registry
def register_model(name: str, model: ModelInterface):
    """Add ``model`` to the global registry under ``name``.

    Convenience wrapper that forwards both arguments to the shared
    registry's ``register_model`` method.
    """
    shared_registry = _model_registry
    shared_registry.register_model(name, model)
def get_model(name: str) -> Optional[ModelInterface]:
    """Look up ``name`` in the global registry.

    Returns the result of the shared registry's ``get_model`` call for
    ``name`` unchanged.
    """
    found = _model_registry.get_model(name)
    return found
def get_all_models() -> Dict[str, ModelInterface]:
    """Return the name -> model mapping of the global registry.

    The value is whatever the shared registry's ``get_all_models`` method
    returns.
    """
    all_models = _model_registry.get_all_models()
    return all_models
# Public API of this module: the registry class, the helper functions for the
# global registry, and the interface classes imported from
# NN.models.model_interfaces, which are re-exported here.
__all__ = [
'ModelRegistry',
'get_model_registry',
'register_model',
'get_model',
'get_all_models',
'ModelInterface',
'CNNModelInterface',
'RLAgentInterface',
'ExtremaTrainerInterface'
]