refactoring

Dobromir Popov
2025-09-08 23:57:21 +03:00
parent 98ebbe5089
commit c3a94600c8
50 changed files with 856 additions and 1302 deletions

View File

@@ -24,7 +24,7 @@ import json
 # Import checkpoint management
 import torch
-from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from NN.training.model_manager import save_checkpoint, load_best_checkpoint
 logger = logging.getLogger(__name__)
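
Editor's note: for the import swap above to be drop-in, the new module must expose the old helper names. A minimal sketch of that contract, assuming torch checkpoints under ./models/saved and a "<name>_best.pt" naming scheme (both assumptions, not the repo's actual layout); the call sites later in this commit unpack load_best_checkpoint() as (file_path, metadata):

import os
from typing import Any, Dict, Optional, Tuple

import torch

CHECKPOINT_DIR = "./models/saved"  # assumed default; the orchestrator reads config.paths

def save_checkpoint(model: Any, model_name: str,
                    metadata: Optional[Dict[str, Any]] = None) -> str:
    """Save model weights plus metadata; return the checkpoint path. Signature assumed."""
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    path = os.path.join(CHECKPOINT_DIR, f"{model_name}_best.pt")
    torch.save({"state_dict": model.state_dict(), "metadata": metadata or {}}, path)
    return path

def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Return (file_path, metadata) for the best checkpoint, or None if missing."""
    path = os.path.join(CHECKPOINT_DIR, f"{model_name}_best.pt")
    if not os.path.exists(path):
        return None
    payload = torch.load(path, map_location="cpu")
    return path, payload.get("metadata", {})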

View File

@@ -21,7 +21,7 @@ import pandas as pd
 # Import checkpoint management
 import torch
-from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from NN.training.model_manager import save_checkpoint, load_best_checkpoint
 logger = logging.getLogger(__name__)

View File

@@ -32,9 +32,9 @@ import torch.optim as optim
 from .config import get_config
 from .data_provider import DataProvider
 from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
-from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
-from NN.models.cob_rl_model import COBRLModelInterface  # Specific import for COB RL Interface
+from NN.training.model_manager import create_model_manager, ModelManager, ModelMetrics, CheckpointMetadata
+from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface  # Import from new file
+from NN.models.cob_rl_model import COBRLModelInterface  # Specific import for COB RL Interface
 from core.extrema_trainer import ExtremaTrainer  # Import ExtremaTrainer for its interface
 # Import COB integration for real-time market microstructure data
@@ -92,12 +92,12 @@ class TradingOrchestrator:
     Includes EnhancedRealtimeTrainingSystem for continuous learning
     """
-    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
+    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_manager: Optional[ModelManager] = None):
         """Initialize the enhanced orchestrator with full ML capabilities"""
         self.config = get_config()
         self.data_provider = data_provider or DataProvider()
         self.universal_adapter = UniversalDataAdapter(self.data_provider)
-        self.model_registry = model_registry or get_model_registry()
+        self.model_manager = model_manager or create_model_manager()
         self.enhanced_rl_training = enhanced_rl_training
         # Configuration - AGGRESSIVE for more training data
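
Editor's note: a hypothetical wiring sketch for the new constructor; the module paths in the imports are assumed from the relative imports above, not confirmed by the diff:

from core.orchestrator import TradingOrchestrator  # assumed module path for this file
from core.data_provider import DataProvider        # assumed absolute path for data_provider.py
from NN.training.model_manager import create_model_manager

manager = create_model_manager()
orchestrator = TradingOrchestrator(
    data_provider=DataProvider(),
    enhanced_rl_training=True,
    model_manager=manager,  # optional; falls back to create_model_manager() when None
)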
@@ -114,14 +114,12 @@ class TradingOrchestrator:
         self.current_positions: Dict[str, Dict] = {}  # {symbol: {side, size, entry_price, entry_time, pnl}}
         self.trading_executor = None  # Will be set by dashboard or external system
-        # Dynamic weights (will be adapted based on performance)
-        self.model_weights: Dict[str, float] = {}  # {model_name: weight}
-        self._initialize_default_weights()
+        # Model management delegated to unified ModelManager
+        # self.model_weights and self.model_performance are now handled by self.model_manager
         # State tracking
         self.last_decision_time: Dict[str, datetime] = {}  # {symbol: datetime}
         self.recent_decisions: Dict[str, List[TradingDecision]] = {}  # {symbol: List[TradingDecision]}
-        self.model_performance: Dict[str, Dict[str, Any]] = {}  # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
         # Model prediction tracking for dashboard visualization
         self.recent_dqn_predictions: Dict[str, deque] = {}  # {symbol: List[Dict]} - Recent DQN predictions
@@ -228,7 +226,7 @@ class TradingOrchestrator:
             try:
                 self.rl_agent.load_best_checkpoint()  # This loads the state into the model
                 # Check if we have checkpoints available
-                from utils.checkpoint_manager import load_best_checkpoint
+                from NN.training.model_manager import load_best_checkpoint
                 result = load_best_checkpoint("dqn_agent")
                 if result:
                     file_path, metadata = result
@@ -268,7 +266,7 @@ class TradingOrchestrator:
             # Load best checkpoint and capture initial state
             checkpoint_loaded = False
             try:
-                from utils.checkpoint_manager import load_best_checkpoint
+                from NN.training.model_manager import load_best_checkpoint
                 result = load_best_checkpoint("enhanced_cnn")
                 if result:
                     file_path, metadata = result
@@ -374,7 +372,7 @@ class TradingOrchestrator:
             # Load best checkpoint
             checkpoint_loaded = False
             try:
-                from utils.checkpoint_manager import load_best_checkpoint
+                from NN.training.model_manager import load_best_checkpoint
                 result = load_best_checkpoint("transformer")
                 if result:
                     file_path, metadata = result
@@ -408,7 +406,7 @@ class TradingOrchestrator:
             # Load best checkpoint
             checkpoint_loaded = False
             try:
-                from utils.checkpoint_manager import load_best_checkpoint
+                from NN.training.model_manager import load_best_checkpoint
                 result = load_best_checkpoint("decision")
                 if result:
                     file_path, metadata = result
@@ -455,7 +453,7 @@ class TradingOrchestrator:
         if self.rl_agent:
             try:
                 rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
-                self.register_model(rl_interface, weight=0.3)
+                # RL model registration handled by ModelManager
                 logger.info("RL Agent registered successfully")
             except Exception as e:
                 logger.error(f"Failed to register RL Agent: {e}")
@@ -464,7 +462,7 @@ class TradingOrchestrator:
         if self.cnn_model:
             try:
                 cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
-                self.register_model(cnn_interface, weight=0.4)
+                # CNN model registration handled by ModelManager
                 logger.info("CNN Model registered successfully")
             except Exception as e:
                 logger.error(f"Failed to register CNN Model: {e}")
@@ -490,7 +488,7 @@ class TradingOrchestrator:
                 return 30.0  # MB
             extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
-            self.register_model(extrema_interface, weight=0.15)  # Lower weight for extrema signals
+            # Extrema model registration handled by ModelManager
             logger.info("Extrema Trainer registered successfully")
         except Exception as e:
             logger.error(f"Failed to register Extrema Trainer: {e}")
@@ -521,7 +519,7 @@ class TradingOrchestrator:
                 return 60.0  # MB estimate for transformer
             transformer_interface = TransformerModelInterface(self.transformer_model, self.transformer_trainer, name="transformer")
-            self.register_model(transformer_interface, weight=0.2)
+            # Transformer model registration handled by ModelManager
             logger.info("Transformer Model registered successfully")
         except Exception as e:
             logger.error(f"Failed to register Transformer Model: {e}")
@@ -547,14 +545,14 @@ class TradingOrchestrator:
                 return 40.0  # MB estimate for decision model
             decision_interface = DecisionModelInterface(self.decision_model, name="decision")
-            self.register_model(decision_interface, weight=0.15)
+            # Decision model registration handled by ModelManager
             logger.info("Decision Fusion Model registered successfully")
         except Exception as e:
             logger.error(f"Failed to register Decision Fusion Model: {e}")
-        # Normalize weights after all registrations
-        self._normalize_weights()
-        logger.info(f"Current model weights: {self.model_weights}")
+        # Model weight normalization handled by ModelManager
+        # Model weights now handled by ModelManager
+        logger.info("Model management delegated to unified ModelManager")
         logger.info("COB_RL model removed - cleaner architecture pending COB data quality fixes")
     except Exception as e:
@@ -627,7 +625,7 @@ class TradingOrchestrator:
             state = {
                 'model_states': {k: {sk: sv for sk, sv in v.items() if sk != 'checkpoint_loaded'}  # Exclude non-serializable
                                  for k, v in self.model_states.items()},
-                'model_weights': self.model_weights,
+                # 'model_weights': self.model_weights,  # Now handled by ModelManager
                 'last_trained_symbols': list(self.last_trained_symbols.keys())
             }
             save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
@@ -644,7 +642,7 @@ class TradingOrchestrator:
             with open(save_path, 'r') as f:
                 state = json.load(f)
             self.model_states.update(state.get('model_states', {}))
-            self.model_weights = state.get('model_weights', self.model_weights)
+            # self.model_weights = state.get('model_weights', {})  # Now handled by ModelManager
             self.last_trained_symbols = {s: datetime.now() for s in state.get('last_trained_symbols', [])}  # Restore with current time
             logger.info(f"Orchestrator state loaded from {save_path}")
         except Exception as e:
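
Editor's note: after these two hunks, orchestrator_state.json carries only model states and trained symbols. A sketch of the persisted shape; the per-model field names and the symbol are illustrative, not taken from the repo:

state = {
    "model_states": {
        # per-model entries minus the non-serializable 'checkpoint_loaded' key
        "dqn_agent": {"initial_loss": None, "current_loss": None},  # illustrative fields
    },
    # symbols only; timestamps are reset to datetime.now() when reloaded
    "last_trained_symbols": ["ETH/USDT"],
}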
@@ -948,62 +946,10 @@ class TradingOrchestrator:
         return np.array(padded_features[-sequence_length:]).astype(np.float32)  # Ensure correct length
-    def _initialize_default_weights(self):
-        """Initialize default model weights from config"""
-        self.model_weights = {
-            'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
-            'RL': self.config.orchestrator.get('rl_weight', 0.3)
-        }
+    # Model management methods removed - all handled by unified ModelManager
+    # Use self.model_manager for all model operations
-    def register_model(self, model: ModelInterface, weight: float = None) -> bool:
-        """Register a new model with the orchestrator"""
-        try:
-            # Register with model registry
-            if not self.model_registry.register_model(model):
-                return False
-            # Set weight
-            if weight is not None:
-                self.model_weights[model.name] = weight
-            elif model.name not in self.model_weights:
-                self.model_weights[model.name] = 0.1  # Default low weight for new models
-            # Initialize performance tracking
-            if model.name not in self.model_performance:
-                self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
-            logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
-            self._normalize_weights()
-            return True
-        except Exception as e:
-            logger.error(f"Error registering model {model.name}: {e}")
-            return False
-    def unregister_model(self, model_name: str) -> bool:
-        """Unregister a model"""
-        try:
-            if self.model_registry.unregister_model(model_name):
-                if model_name in self.model_weights:
-                    del self.model_weights[model_name]
-                if model_name in self.model_performance:
-                    del self.model_performance[model_name]
-                self._normalize_weights()
-                logger.info(f"Unregistered {model_name} model")
-                return True
-            return False
-        except Exception as e:
-            logger.error(f"Error unregistering model {model_name}: {e}")
-            return False
-    def _normalize_weights(self):
-        """Normalize model weights to sum to 1.0"""
-        total_weight = sum(self.model_weights.values())
-        if total_weight > 0:
-            for model_name in self.model_weights:
-                self.model_weights[model_name] /= total_weight
+    # Weight normalization removed - handled by ModelManager
     def add_decision_callback(self, callback):
         """Add a callback function to be called when decisions are made"""
@@ -1066,9 +1012,7 @@ class TradingOrchestrator:
             except Exception as e:
                 logger.error(f"Error in decision callback: {e}")
-            # Clean up memory periodically
-            if len(self.recent_decisions[symbol]) % 50 == 0:
-                self.model_registry.cleanup_all_models()
+            # Model cleanup handled by ModelManager
             return decision
@@ -1077,38 +1021,17 @@ class TradingOrchestrator:
             return None
     async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
-        """Get predictions from all registered models"""
+        """Get predictions from all registered models via ModelManager"""
         predictions = []
-        for model_name, model in self.model_registry.models.items():
-            try:
-                if isinstance(model, CNNModelInterface):
-                    # Get CNN predictions for each timeframe
-                    cnn_predictions = await self._get_cnn_predictions(model, symbol)
-                    predictions.extend(cnn_predictions)
-                elif isinstance(model, RLAgentInterface):
-                    # Get RL prediction
-                    rl_prediction = await self._get_rl_prediction(model, symbol)
-                    if rl_prediction:
-                        predictions.append(rl_prediction)
-                elif isinstance(model, COBRLModelInterface):
-                    # Get COB RL prediction
-                    cob_prediction = await self._get_cob_rl_prediction(model, symbol)
-                    if cob_prediction:
-                        predictions.append(cob_prediction)
-                else:
-                    # Generic model interface
-                    generic_prediction = await self._get_generic_prediction(model, symbol)
-                    if generic_prediction:
-                        predictions.append(generic_prediction)
-            except Exception as e:
-                logger.error(f"Error getting prediction from {model_name}: {e}")
-                continue
+        # This method now delegates to ModelManager for model iteration
+        # The actual model prediction logic has been moved to individual methods
+        # that are called by the ModelManager
+        logger.debug(f"Getting predictions for {symbol} - model management handled by ModelManager")
+        # For now, return empty list as this method needs to be restructured
+        # to work with the new ModelManager architecture
         return predictions
     async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
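
Editor's note: the new body above is a stub that returns an empty list. One possible follow-up restructuring is sketched below; get_models() is a hypothetical ModelManager accessor, not an API this commit introduces, and the dispatch mirrors the removed isinstance chain:

async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
    """Sketch: iterate models owned by ModelManager instead of a local registry."""
    predictions: List[Prediction] = []
    for model in self.model_manager.get_models():  # hypothetical accessor
        try:
            if isinstance(model, CNNModelInterface):
                predictions.extend(await self._get_cnn_predictions(model, symbol))
            elif isinstance(model, RLAgentInterface):
                prediction = await self._get_rl_prediction(model, symbol)
                if prediction:
                    predictions.append(prediction)
        except Exception as e:
            logger.error(f"Error getting prediction from {model.name}: {e}")
    return predictions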
@@ -1454,7 +1377,7 @@ class TradingOrchestrator:
         try:
             reasoning = {
                 'predictions': len(predictions),
-                'weights': self.model_weights.copy(),
+                # 'weights': {},  # Now handled by ModelManager
                 'models_used': [pred.model_name for pred in predictions]
             }
@@ -1468,7 +1391,7 @@ class TradingOrchestrator:
             # Process all predictions
             for pred in predictions:
                 # Get model weight
-                model_weight = self.model_weights.get(pred.model_name, 0.1)
+                model_weight = 0.1  # Default weight, now managed by ModelManager
                 # Weight by confidence and timeframe importance
                 timeframe_weight = self._get_timeframe_weight(pred.timeframe)
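
Editor's note: with the interim flat weight, only confidence and timeframe still differentiate models. A plausible reading of the weighting steps above (the exact combination formula is not shown in this hunk):

# Interim contribution of one prediction to the combined signal (illustrative):
contribution = 0.1 * self._get_timeframe_weight(pred.timeframe) * pred.confidence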
@@ -1518,7 +1441,7 @@ class TradingOrchestrator:
         # Get memory usage stats
         try:
-            memory_usage = self.model_registry.get_memory_stats() if hasattr(self.model_registry, 'get_memory_stats') else {}
+            memory_usage = self.model_manager.get_storage_stats() if hasattr(self.model_manager, 'get_storage_stats') else {}
         except Exception:
             memory_usage = {}
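
Editor's note: the hasattr guard above can also be written as a single getattr probe with a callable default, which behaves identically when older ModelManager builds lack get_storage_stats(); an idiom note, not a behavior change:

memory_usage = getattr(self.model_manager, "get_storage_stats", lambda: {})()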
@@ -1571,31 +1494,8 @@ class TradingOrchestrator:
         }
         return weights.get(timeframe, 0.5)
-    def update_model_performance(self, model_name: str, was_correct: bool):
-        """Update performance tracking for a model"""
-        if model_name in self.model_performance:
-            self.model_performance[model_name]['total'] += 1
-            if was_correct:
-                self.model_performance[model_name]['correct'] += 1
-            # Update accuracy
-            total = self.model_performance[model_name]['total']
-            correct = self.model_performance[model_name]['correct']
-            self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0
-    def adapt_weights(self):
-        """Dynamically adapt model weights based on performance"""
-        try:
-            for model_name, performance in self.model_performance.items():
-                if performance['total'] > 0:
-                    # Adjust weight based on relative performance
-                    accuracy = performance['correct'] / performance['total']
-                    self.model_weights[model_name] = accuracy
-                    logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")
-        except Exception as e:
-            logger.error(f"Error adapting weights: {e}")
+    # Model performance and weight adaptation removed - handled by ModelManager
+    # Use self.model_manager for all model performance tracking
     def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
         """Get recent decisions for a symbol"""
@@ -1606,8 +1506,8 @@ class TradingOrchestrator:
     def get_performance_metrics(self) -> Dict[str, Any]:
         """Get performance metrics for the orchestrator"""
         return {
-            'model_performance': self.model_performance.copy(),
-            'weights': self.model_weights.copy(),
+            # 'model_performance': {},  # Now handled by ModelManager
+            # 'weights': {},  # Now handled by ModelManager
             'configuration': {
                 'confidence_threshold': self.confidence_threshold,
                 'decision_frequency': self.decision_frequency
@@ -1630,7 +1530,7 @@ class TradingOrchestrator:
         current_time = time.time()
         cache_expiry = 60  # seconds
-        from utils.checkpoint_manager import load_best_checkpoint
+        from NN.training.model_manager import load_best_checkpoint
         # Update each model with REAL checkpoint data (cached)
         # Note: COB_RL removed - functionality integrated into Enhanced CNN
@@ -1872,7 +1772,7 @@ class TradingOrchestrator:
             'decision_fusion_enabled': self.decision_fusion_enabled,
             'symbols_tracking': len(self.symbols),
             'recent_decisions_count': sum(len(decisions) for decisions in self.recent_decisions.values()),
-            'model_weights': self.model_weights.copy(),
+            # 'model_weights': {},  # Now handled by ModelManager
             'realtime_processing': self.realtime_processing
         }

View File

@@ -114,10 +114,10 @@ class RealtimeRLCOBTrader:
         self.min_confidence_threshold = min_confidence_threshold
         self.required_confident_predictions = required_confident_predictions
-        # Initialize CheckpointManager (either provided or get global instance)
+        # Initialize ModelManager (either provided or get global instance)
         if checkpoint_manager is None:
-            from utils.checkpoint_manager import get_checkpoint_manager
-            self.checkpoint_manager = get_checkpoint_manager()
+            from NN.training.model_manager import create_model_manager
+            self.checkpoint_manager = create_model_manager()
         else:
             self.checkpoint_manager = checkpoint_manager
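
Editor's note: a hypothetical construction sketch for the trader after this change; only the checkpoint_manager parameter and the None fallback are taken from the diff, the other argument values are placeholders:

from NN.training.model_manager import create_model_manager

trader = RealtimeRLCOBTrader(
    min_confidence_threshold=0.6,        # placeholder value
    required_confident_predictions=3,    # placeholder value
    checkpoint_manager=create_model_manager(),  # or pass None to build the default internally
)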