refactoring
This commit is contained in:
@@ -32,9 +32,9 @@ import torch.optim as optim
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
||||
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
||||
from NN.training.model_manager import create_model_manager, ModelManager, ModelMetrics, CheckpointMetadata
|
||||
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file
|
||||
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
||||
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
@@ -92,12 +92,12 @@ class TradingOrchestrator:
|
||||
Includes EnhancedRealtimeTrainingSystem for continuous learning
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
|
||||
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_manager: Optional[ModelManager] = None):
|
||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||
self.config = get_config()
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
||||
self.model_registry = model_registry or get_model_registry()
|
||||
self.model_manager = model_manager or create_model_manager()
|
||||
self.enhanced_rl_training = enhanced_rl_training
|
||||
|
||||
# Configuration - AGGRESSIVE for more training data
|
||||
@@ -114,14 +114,12 @@ class TradingOrchestrator:
|
||||
self.current_positions: Dict[str, Dict] = {} # {symbol: {side, size, entry_price, entry_time, pnl}}
|
||||
self.trading_executor = None # Will be set by dashboard or external system
|
||||
|
||||
# Dynamic weights (will be adapted based on performance)
|
||||
self.model_weights: Dict[str, float] = {} # {model_name: weight}
|
||||
self._initialize_default_weights()
|
||||
|
||||
# Model management delegated to unified ModelManager
|
||||
# self.model_weights and self.model_performance are now handled by self.model_manager
|
||||
|
||||
# State tracking
|
||||
self.last_decision_time: Dict[str, datetime] = {} # {symbol: datetime}
|
||||
self.recent_decisions: Dict[str, List[TradingDecision]] = {} # {symbol: List[TradingDecision]}
|
||||
self.model_performance: Dict[str, Dict[str, Any]] = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
||||
|
||||
# Model prediction tracking for dashboard visualization
|
||||
self.recent_dqn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent DQN predictions
|
||||
@@ -228,7 +226,7 @@ class TradingOrchestrator:
|
||||
try:
|
||||
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
||||
# Check if we have checkpoints available
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("dqn_agent")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
@@ -268,7 +266,7 @@ class TradingOrchestrator:
|
||||
# Load best checkpoint and capture initial state
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("enhanced_cnn")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
@@ -374,7 +372,7 @@ class TradingOrchestrator:
|
||||
# Load best checkpoint
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("transformer")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
@@ -408,7 +406,7 @@ class TradingOrchestrator:
|
||||
# Load best checkpoint
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("decision")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
@@ -455,7 +453,7 @@ class TradingOrchestrator:
|
||||
if self.rl_agent:
|
||||
try:
|
||||
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
||||
self.register_model(rl_interface, weight=0.3)
|
||||
# RL model registration handled by ModelManager
|
||||
logger.info("RL Agent registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register RL Agent: {e}")
|
||||
@@ -464,7 +462,7 @@ class TradingOrchestrator:
|
||||
if self.cnn_model:
|
||||
try:
|
||||
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
||||
self.register_model(cnn_interface, weight=0.4)
|
||||
# CNN model registration handled by ModelManager
|
||||
logger.info("CNN Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register CNN Model: {e}")
|
||||
@@ -490,7 +488,7 @@ class TradingOrchestrator:
|
||||
return 30.0 # MB
|
||||
|
||||
extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
|
||||
self.register_model(extrema_interface, weight=0.15) # Lower weight for extrema signals
|
||||
# Extrema model registration handled by ModelManager
|
||||
logger.info("Extrema Trainer registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Extrema Trainer: {e}")
|
||||
@@ -521,7 +519,7 @@ class TradingOrchestrator:
|
||||
return 60.0 # MB estimate for transformer
|
||||
|
||||
transformer_interface = TransformerModelInterface(self.transformer_model, self.transformer_trainer, name="transformer")
|
||||
self.register_model(transformer_interface, weight=0.2)
|
||||
# Transformer model registration handled by ModelManager
|
||||
logger.info("Transformer Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Transformer Model: {e}")
|
||||
@@ -547,14 +545,14 @@ class TradingOrchestrator:
|
||||
return 40.0 # MB estimate for decision model
|
||||
|
||||
decision_interface = DecisionModelInterface(self.decision_model, name="decision")
|
||||
self.register_model(decision_interface, weight=0.15)
|
||||
# Decision model registration handled by ModelManager
|
||||
logger.info("Decision Fusion Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Decision Fusion Model: {e}")
|
||||
|
||||
# Normalize weights after all registrations
|
||||
self._normalize_weights()
|
||||
logger.info(f"Current model weights: {self.model_weights}")
|
||||
# Model weight normalization handled by ModelManager
|
||||
# Model weights now handled by ModelManager
|
||||
logger.info("Model management delegated to unified ModelManager")
|
||||
logger.info("COB_RL model removed - cleaner architecture pending COB data quality fixes")
|
||||
|
||||
except Exception as e:
|
||||
@@ -627,7 +625,7 @@ class TradingOrchestrator:
|
||||
state = {
|
||||
'model_states': {k: {sk: sv for sk, sv in v.items() if sk != 'checkpoint_loaded'} # Exclude non-serializable
|
||||
for k, v in self.model_states.items()},
|
||||
'model_weights': self.model_weights,
|
||||
# 'model_weights': self.model_weights, # Now handled by ModelManager
|
||||
'last_trained_symbols': list(self.last_trained_symbols.keys())
|
||||
}
|
||||
save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
|
||||
@@ -644,7 +642,7 @@ class TradingOrchestrator:
|
||||
with open(save_path, 'r') as f:
|
||||
state = json.load(f)
|
||||
self.model_states.update(state.get('model_states', {}))
|
||||
self.model_weights = state.get('model_weights', self.model_weights)
|
||||
# self.model_weights = state.get('model_weights', {}) # Now handled by ModelManager
|
||||
self.last_trained_symbols = {s: datetime.now() for s in state.get('last_trained_symbols', [])} # Restore with current time
|
||||
logger.info(f"Orchestrator state loaded from {save_path}")
|
||||
except Exception as e:
|
||||
@@ -948,62 +946,10 @@ class TradingOrchestrator:
|
||||
|
||||
return np.array(padded_features[-sequence_length:]).astype(np.float32) # Ensure correct length
|
||||
|
||||
def _initialize_default_weights(self):
|
||||
"""Initialize default model weights from config"""
|
||||
self.model_weights = {
|
||||
'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
|
||||
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
||||
}
|
||||
# Model management methods removed - all handled by unified ModelManager
|
||||
# Use self.model_manager for all model operations
|
||||
|
||||
def register_model(self, model: ModelInterface, weight: float = None) -> bool:
|
||||
"""Register a new model with the orchestrator"""
|
||||
try:
|
||||
# Register with model registry
|
||||
if not self.model_registry.register_model(model):
|
||||
return False
|
||||
|
||||
# Set weight
|
||||
if weight is not None:
|
||||
self.model_weights[model.name] = weight
|
||||
elif model.name not in self.model_weights:
|
||||
self.model_weights[model.name] = 0.1 # Default low weight for new models
|
||||
|
||||
# Initialize performance tracking
|
||||
if model.name not in self.model_performance:
|
||||
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
|
||||
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
||||
self._normalize_weights()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering model {model.name}: {e}")
|
||||
return False
|
||||
|
||||
def unregister_model(self, model_name: str) -> bool:
|
||||
"""Unregister a model"""
|
||||
try:
|
||||
if self.model_registry.unregister_model(model_name):
|
||||
if model_name in self.model_weights:
|
||||
del self.model_weights[model_name]
|
||||
if model_name in self.model_performance:
|
||||
del self.model_performance[model_name]
|
||||
|
||||
self._normalize_weights()
|
||||
logger.info(f"Unregistered {model_name} model")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error unregistering model {model_name}: {e}")
|
||||
return False
|
||||
|
||||
def _normalize_weights(self):
|
||||
"""Normalize model weights to sum to 1.0"""
|
||||
total_weight = sum(self.model_weights.values())
|
||||
if total_weight > 0:
|
||||
for model_name in self.model_weights:
|
||||
self.model_weights[model_name] /= total_weight
|
||||
# Weight normalization removed - handled by ModelManager
|
||||
|
||||
def add_decision_callback(self, callback):
|
||||
"""Add a callback function to be called when decisions are made"""
|
||||
@@ -1066,9 +1012,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in decision callback: {e}")
|
||||
|
||||
# Clean up memory periodically
|
||||
if len(self.recent_decisions[symbol]) % 50 == 0:
|
||||
self.model_registry.cleanup_all_models()
|
||||
# Model cleanup handled by ModelManager
|
||||
|
||||
return decision
|
||||
|
||||
@@ -1077,38 +1021,17 @@ class TradingOrchestrator:
|
||||
return None
|
||||
|
||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from all registered models"""
|
||||
"""Get predictions from all registered models via ModelManager"""
|
||||
predictions = []
|
||||
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
try:
|
||||
if isinstance(model, CNNModelInterface):
|
||||
# Get CNN predictions for each timeframe
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
||||
predictions.extend(cnn_predictions)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction
|
||||
rl_prediction = await self._get_rl_prediction(model, symbol)
|
||||
if rl_prediction:
|
||||
predictions.append(rl_prediction)
|
||||
|
||||
elif isinstance(model, COBRLModelInterface):
|
||||
# Get COB RL prediction
|
||||
cob_prediction = await self._get_cob_rl_prediction(model, symbol)
|
||||
if cob_prediction:
|
||||
predictions.append(cob_prediction)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
generic_prediction = await self._get_generic_prediction(model, symbol)
|
||||
if generic_prediction:
|
||||
predictions.append(generic_prediction)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
continue
|
||||
|
||||
# This method now delegates to ModelManager for model iteration
|
||||
# The actual model prediction logic has been moved to individual methods
|
||||
# that are called by the ModelManager
|
||||
|
||||
logger.debug(f"Getting predictions for {symbol} - model management handled by ModelManager")
|
||||
|
||||
# For now, return empty list as this method needs to be restructured
|
||||
# to work with the new ModelManager architecture
|
||||
return predictions
|
||||
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||
@@ -1454,7 +1377,7 @@ class TradingOrchestrator:
|
||||
try:
|
||||
reasoning = {
|
||||
'predictions': len(predictions),
|
||||
'weights': self.model_weights.copy(),
|
||||
# 'weights': {}, # Now handled by ModelManager
|
||||
'models_used': [pred.model_name for pred in predictions]
|
||||
}
|
||||
|
||||
@@ -1468,7 +1391,7 @@ class TradingOrchestrator:
|
||||
# Process all predictions
|
||||
for pred in predictions:
|
||||
# Get model weight
|
||||
model_weight = self.model_weights.get(pred.model_name, 0.1)
|
||||
model_weight = 0.1 # Default weight, now managed by ModelManager
|
||||
|
||||
# Weight by confidence and timeframe importance
|
||||
timeframe_weight = self._get_timeframe_weight(pred.timeframe)
|
||||
@@ -1518,7 +1441,7 @@ class TradingOrchestrator:
|
||||
|
||||
# Get memory usage stats
|
||||
try:
|
||||
memory_usage = self.model_registry.get_memory_stats() if hasattr(self.model_registry, 'get_memory_stats') else {}
|
||||
memory_usage = self.model_manager.get_storage_stats() if hasattr(self.model_manager, 'get_storage_stats') else {}
|
||||
except Exception:
|
||||
memory_usage = {}
|
||||
|
||||
@@ -1571,31 +1494,8 @@ class TradingOrchestrator:
|
||||
}
|
||||
return weights.get(timeframe, 0.5)
|
||||
|
||||
def update_model_performance(self, model_name: str, was_correct: bool):
|
||||
"""Update performance tracking for a model"""
|
||||
if model_name in self.model_performance:
|
||||
self.model_performance[model_name]['total'] += 1
|
||||
if was_correct:
|
||||
self.model_performance[model_name]['correct'] += 1
|
||||
|
||||
# Update accuracy
|
||||
total = self.model_performance[model_name]['total']
|
||||
correct = self.model_performance[model_name]['correct']
|
||||
self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0
|
||||
|
||||
def adapt_weights(self):
|
||||
"""Dynamically adapt model weights based on performance"""
|
||||
try:
|
||||
for model_name, performance in self.model_performance.items():
|
||||
if performance['total'] > 0:
|
||||
# Adjust weight based on relative performance
|
||||
accuracy = performance['correct'] / performance['total']
|
||||
self.model_weights[model_name] = accuracy
|
||||
|
||||
logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adapting weights: {e}")
|
||||
# Model performance and weight adaptation removed - handled by ModelManager
|
||||
# Use self.model_manager for all model performance tracking
|
||||
|
||||
def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
|
||||
"""Get recent decisions for a symbol"""
|
||||
@@ -1606,8 +1506,8 @@ class TradingOrchestrator:
|
||||
def get_performance_metrics(self) -> Dict[str, Any]:
|
||||
"""Get performance metrics for the orchestrator"""
|
||||
return {
|
||||
'model_performance': self.model_performance.copy(),
|
||||
'weights': self.model_weights.copy(),
|
||||
# 'model_performance': {}, # Now handled by ModelManager
|
||||
# 'weights': {}, # Now handled by ModelManager
|
||||
'configuration': {
|
||||
'confidence_threshold': self.confidence_threshold,
|
||||
'decision_frequency': self.decision_frequency
|
||||
@@ -1630,7 +1530,7 @@ class TradingOrchestrator:
|
||||
current_time = time.time()
|
||||
cache_expiry = 60 # seconds
|
||||
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
|
||||
# Update each model with REAL checkpoint data (cached)
|
||||
# Note: COB_RL removed - functionality integrated into Enhanced CNN
|
||||
@@ -1872,7 +1772,7 @@ class TradingOrchestrator:
|
||||
'decision_fusion_enabled': self.decision_fusion_enabled,
|
||||
'symbols_tracking': len(self.symbols),
|
||||
'recent_decisions_count': sum(len(decisions) for decisions in self.recent_decisions.values()),
|
||||
'model_weights': self.model_weights.copy(),
|
||||
# 'model_weights': {}, # Now handled by ModelManager
|
||||
'realtime_processing': self.realtime_processing
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user