|
|
|
|
@@ -32,9 +32,9 @@ import torch.optim as optim
|
|
|
|
|
from .config import get_config
|
|
|
|
|
from .data_provider import DataProvider
|
|
|
|
|
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
|
|
|
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
|
|
|
|
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
|
|
|
|
from NN.training.model_manager import create_model_manager, ModelManager, ModelMetrics, CheckpointMetadata
|
|
|
|
|
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file
|
|
|
|
|
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
|
|
|
|
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
|
|
|
|
|
|
|
|
|
# Import COB integration for real-time market microstructure data
|
|
|
|
|
@@ -92,12 +92,12 @@ class TradingOrchestrator:
|
|
|
|
|
Includes EnhancedRealtimeTrainingSystem for continuous learning
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
|
|
|
|
|
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_manager: Optional[ModelManager] = None):
|
|
|
|
|
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
|
|
|
|
self.config = get_config()
|
|
|
|
|
self.data_provider = data_provider or DataProvider()
|
|
|
|
|
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
|
|
|
|
self.model_registry = model_registry or get_model_registry()
|
|
|
|
|
self.model_manager = model_manager or create_model_manager()
|
|
|
|
|
self.enhanced_rl_training = enhanced_rl_training
|
|
|
|
|
|
|
|
|
|
# Configuration - AGGRESSIVE for more training data
|
|
|
|
|
@@ -114,14 +114,12 @@ class TradingOrchestrator:
|
|
|
|
|
self.current_positions: Dict[str, Dict] = {} # {symbol: {side, size, entry_price, entry_time, pnl}}
|
|
|
|
|
self.trading_executor = None # Will be set by dashboard or external system
|
|
|
|
|
|
|
|
|
|
# Dynamic weights (will be adapted based on performance)
|
|
|
|
|
self.model_weights: Dict[str, float] = {} # {model_name: weight}
|
|
|
|
|
self._initialize_default_weights()
|
|
|
|
|
|
|
|
|
|
# Model management delegated to unified ModelManager
|
|
|
|
|
# self.model_weights and self.model_performance are now handled by self.model_manager
|
|
|
|
|
|
|
|
|
|
# State tracking
|
|
|
|
|
self.last_decision_time: Dict[str, datetime] = {} # {symbol: datetime}
|
|
|
|
|
self.recent_decisions: Dict[str, List[TradingDecision]] = {} # {symbol: List[TradingDecision]}
|
|
|
|
|
self.model_performance: Dict[str, Dict[str, Any]] = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
|
|
|
|
|
|
|
|
|
# Model prediction tracking for dashboard visualization
|
|
|
|
|
self.recent_dqn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent DQN predictions
|
|
|
|
|
@@ -228,7 +226,7 @@ class TradingOrchestrator:
|
|
|
|
|
try:
|
|
|
|
|
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
|
|
|
|
# Check if we have checkpoints available
|
|
|
|
|
from utils.checkpoint_manager import load_best_checkpoint
|
|
|
|
|
from NN.training.model_manager import load_best_checkpoint
|
|
|
|
|
result = load_best_checkpoint("dqn_agent")
|
|
|
|
|
if result:
|
|
|
|
|
file_path, metadata = result
|
|
|
|
|
@@ -268,7 +266,7 @@ class TradingOrchestrator:
|
|
|
|
|
# Load best checkpoint and capture initial state
|
|
|
|
|
checkpoint_loaded = False
|
|
|
|
|
try:
|
|
|
|
|
from utils.checkpoint_manager import load_best_checkpoint
|
|
|
|
|
from NN.training.model_manager import load_best_checkpoint
|
|
|
|
|
result = load_best_checkpoint("enhanced_cnn")
|
|
|
|
|
if result:
|
|
|
|
|
file_path, metadata = result
|
|
|
|
|
@@ -374,7 +372,7 @@ class TradingOrchestrator:
|
|
|
|
|
# Load best checkpoint
|
|
|
|
|
checkpoint_loaded = False
|
|
|
|
|
try:
|
|
|
|
|
from utils.checkpoint_manager import load_best_checkpoint
|
|
|
|
|
from NN.training.model_manager import load_best_checkpoint
|
|
|
|
|
result = load_best_checkpoint("transformer")
|
|
|
|
|
if result:
|
|
|
|
|
file_path, metadata = result
|
|
|
|
|
@@ -408,7 +406,7 @@ class TradingOrchestrator:
|
|
|
|
|
# Load best checkpoint
|
|
|
|
|
checkpoint_loaded = False
|
|
|
|
|
try:
|
|
|
|
|
from utils.checkpoint_manager import load_best_checkpoint
|
|
|
|
|
from NN.training.model_manager import load_best_checkpoint
|
|
|
|
|
result = load_best_checkpoint("decision")
|
|
|
|
|
if result:
|
|
|
|
|
file_path, metadata = result
|
|
|
|
|
@@ -455,7 +453,7 @@ class TradingOrchestrator:
|
|
|
|
|
if self.rl_agent:
|
|
|
|
|
try:
|
|
|
|
|
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
|
|
|
|
self.register_model(rl_interface, weight=0.3)
|
|
|
|
|
# RL model registration handled by ModelManager
|
|
|
|
|
logger.info("RL Agent registered successfully")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to register RL Agent: {e}")
|
|
|
|
|
@@ -464,7 +462,7 @@ class TradingOrchestrator:
|
|
|
|
|
if self.cnn_model:
|
|
|
|
|
try:
|
|
|
|
|
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
|
|
|
|
self.register_model(cnn_interface, weight=0.4)
|
|
|
|
|
# CNN model registration handled by ModelManager
|
|
|
|
|
logger.info("CNN Model registered successfully")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to register CNN Model: {e}")
|
|
|
|
|
@@ -490,7 +488,7 @@ class TradingOrchestrator:
|
|
|
|
|
return 30.0 # MB
|
|
|
|
|
|
|
|
|
|
extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
|
|
|
|
|
self.register_model(extrema_interface, weight=0.15) # Lower weight for extrema signals
|
|
|
|
|
# Extrema model registration handled by ModelManager
|
|
|
|
|
logger.info("Extrema Trainer registered successfully")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to register Extrema Trainer: {e}")
|
|
|
|
|
@@ -521,7 +519,7 @@ class TradingOrchestrator:
|
|
|
|
|
return 60.0 # MB estimate for transformer
|
|
|
|
|
|
|
|
|
|
transformer_interface = TransformerModelInterface(self.transformer_model, self.transformer_trainer, name="transformer")
|
|
|
|
|
self.register_model(transformer_interface, weight=0.2)
|
|
|
|
|
# Transformer model registration handled by ModelManager
|
|
|
|
|
logger.info("Transformer Model registered successfully")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to register Transformer Model: {e}")
|
|
|
|
|
@@ -547,14 +545,14 @@ class TradingOrchestrator:
|
|
|
|
|
return 40.0 # MB estimate for decision model
|
|
|
|
|
|
|
|
|
|
decision_interface = DecisionModelInterface(self.decision_model, name="decision")
|
|
|
|
|
self.register_model(decision_interface, weight=0.15)
|
|
|
|
|
# Decision model registration handled by ModelManager
|
|
|
|
|
logger.info("Decision Fusion Model registered successfully")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to register Decision Fusion Model: {e}")
|
|
|
|
|
|
|
|
|
|
# Normalize weights after all registrations
|
|
|
|
|
self._normalize_weights()
|
|
|
|
|
logger.info(f"Current model weights: {self.model_weights}")
|
|
|
|
|
# Model weight normalization handled by ModelManager
|
|
|
|
|
# Model weights now handled by ModelManager
|
|
|
|
|
logger.info("Model management delegated to unified ModelManager")
|
|
|
|
|
logger.info("COB_RL model removed - cleaner architecture pending COB data quality fixes")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
@@ -627,7 +625,7 @@ class TradingOrchestrator:
|
|
|
|
|
state = {
|
|
|
|
|
'model_states': {k: {sk: sv for sk, sv in v.items() if sk != 'checkpoint_loaded'} # Exclude non-serializable
|
|
|
|
|
for k, v in self.model_states.items()},
|
|
|
|
|
'model_weights': self.model_weights,
|
|
|
|
|
# 'model_weights': self.model_weights, # Now handled by ModelManager
|
|
|
|
|
'last_trained_symbols': list(self.last_trained_symbols.keys())
|
|
|
|
|
}
|
|
|
|
|
save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
|
|
|
|
|
@@ -644,7 +642,7 @@ class TradingOrchestrator:
|
|
|
|
|
with open(save_path, 'r') as f:
|
|
|
|
|
state = json.load(f)
|
|
|
|
|
self.model_states.update(state.get('model_states', {}))
|
|
|
|
|
self.model_weights = state.get('model_weights', self.model_weights)
|
|
|
|
|
# self.model_weights = state.get('model_weights', {}) # Now handled by ModelManager
|
|
|
|
|
self.last_trained_symbols = {s: datetime.now() for s in state.get('last_trained_symbols', [])} # Restore with current time
|
|
|
|
|
logger.info(f"Orchestrator state loaded from {save_path}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
@@ -948,62 +946,10 @@ class TradingOrchestrator:
|
|
|
|
|
|
|
|
|
|
return np.array(padded_features[-sequence_length:]).astype(np.float32) # Ensure correct length
|
|
|
|
|
|
|
|
|
|
def _initialize_default_weights(self):
|
|
|
|
|
"""Initialize default model weights from config"""
|
|
|
|
|
self.model_weights = {
|
|
|
|
|
'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
|
|
|
|
|
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
|
|
|
|
}
|
|
|
|
|
# Model management methods removed - all handled by unified ModelManager
|
|
|
|
|
# Use self.model_manager for all model operations
|
|
|
|
|
|
|
|
|
|
def register_model(self, model: ModelInterface, weight: float = None) -> bool:
|
|
|
|
|
"""Register a new model with the orchestrator"""
|
|
|
|
|
try:
|
|
|
|
|
# Register with model registry
|
|
|
|
|
if not self.model_registry.register_model(model):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# Set weight
|
|
|
|
|
if weight is not None:
|
|
|
|
|
self.model_weights[model.name] = weight
|
|
|
|
|
elif model.name not in self.model_weights:
|
|
|
|
|
self.model_weights[model.name] = 0.1 # Default low weight for new models
|
|
|
|
|
|
|
|
|
|
# Initialize performance tracking
|
|
|
|
|
if model.name not in self.model_performance:
|
|
|
|
|
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
|
|
|
|
|
|
|
|
|
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
|
|
|
|
self._normalize_weights()
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error registering model {model.name}: {e}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def unregister_model(self, model_name: str) -> bool:
|
|
|
|
|
"""Unregister a model"""
|
|
|
|
|
try:
|
|
|
|
|
if self.model_registry.unregister_model(model_name):
|
|
|
|
|
if model_name in self.model_weights:
|
|
|
|
|
del self.model_weights[model_name]
|
|
|
|
|
if model_name in self.model_performance:
|
|
|
|
|
del self.model_performance[model_name]
|
|
|
|
|
|
|
|
|
|
self._normalize_weights()
|
|
|
|
|
logger.info(f"Unregistered {model_name} model")
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error unregistering model {model_name}: {e}")
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _normalize_weights(self):
|
|
|
|
|
"""Normalize model weights to sum to 1.0"""
|
|
|
|
|
total_weight = sum(self.model_weights.values())
|
|
|
|
|
if total_weight > 0:
|
|
|
|
|
for model_name in self.model_weights:
|
|
|
|
|
self.model_weights[model_name] /= total_weight
|
|
|
|
|
# Weight normalization removed - handled by ModelManager
|
|
|
|
|
|
|
|
|
|
def add_decision_callback(self, callback):
|
|
|
|
|
"""Add a callback function to be called when decisions are made"""
|
|
|
|
|
@@ -1066,9 +1012,7 @@ class TradingOrchestrator:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in decision callback: {e}")
|
|
|
|
|
|
|
|
|
|
# Clean up memory periodically
|
|
|
|
|
if len(self.recent_decisions[symbol]) % 50 == 0:
|
|
|
|
|
self.model_registry.cleanup_all_models()
|
|
|
|
|
# Model cleanup handled by ModelManager
|
|
|
|
|
|
|
|
|
|
return decision
|
|
|
|
|
|
|
|
|
|
@@ -1077,38 +1021,17 @@ class TradingOrchestrator:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
|
|
|
|
"""Get predictions from all registered models"""
|
|
|
|
|
"""Get predictions from all registered models via ModelManager"""
|
|
|
|
|
predictions = []
|
|
|
|
|
|
|
|
|
|
for model_name, model in self.model_registry.models.items():
|
|
|
|
|
try:
|
|
|
|
|
if isinstance(model, CNNModelInterface):
|
|
|
|
|
# Get CNN predictions for each timeframe
|
|
|
|
|
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
|
|
|
|
predictions.extend(cnn_predictions)
|
|
|
|
|
|
|
|
|
|
elif isinstance(model, RLAgentInterface):
|
|
|
|
|
# Get RL prediction
|
|
|
|
|
rl_prediction = await self._get_rl_prediction(model, symbol)
|
|
|
|
|
if rl_prediction:
|
|
|
|
|
predictions.append(rl_prediction)
|
|
|
|
|
|
|
|
|
|
elif isinstance(model, COBRLModelInterface):
|
|
|
|
|
# Get COB RL prediction
|
|
|
|
|
cob_prediction = await self._get_cob_rl_prediction(model, symbol)
|
|
|
|
|
if cob_prediction:
|
|
|
|
|
predictions.append(cob_prediction)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# Generic model interface
|
|
|
|
|
generic_prediction = await self._get_generic_prediction(model, symbol)
|
|
|
|
|
if generic_prediction:
|
|
|
|
|
predictions.append(generic_prediction)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error getting prediction from {model_name}: {e}")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# This method now delegates to ModelManager for model iteration
|
|
|
|
|
# The actual model prediction logic has been moved to individual methods
|
|
|
|
|
# that are called by the ModelManager
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Getting predictions for {symbol} - model management handled by ModelManager")
|
|
|
|
|
|
|
|
|
|
# For now, return empty list as this method needs to be restructured
|
|
|
|
|
# to work with the new ModelManager architecture
|
|
|
|
|
return predictions
|
|
|
|
|
|
|
|
|
|
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
|
|
|
|
@@ -1454,7 +1377,7 @@ class TradingOrchestrator:
|
|
|
|
|
try:
|
|
|
|
|
reasoning = {
|
|
|
|
|
'predictions': len(predictions),
|
|
|
|
|
'weights': self.model_weights.copy(),
|
|
|
|
|
# 'weights': {}, # Now handled by ModelManager
|
|
|
|
|
'models_used': [pred.model_name for pred in predictions]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1468,7 +1391,7 @@ class TradingOrchestrator:
|
|
|
|
|
# Process all predictions
|
|
|
|
|
for pred in predictions:
|
|
|
|
|
# Get model weight
|
|
|
|
|
model_weight = self.model_weights.get(pred.model_name, 0.1)
|
|
|
|
|
model_weight = 0.1 # Default weight, now managed by ModelManager
|
|
|
|
|
|
|
|
|
|
# Weight by confidence and timeframe importance
|
|
|
|
|
timeframe_weight = self._get_timeframe_weight(pred.timeframe)
|
|
|
|
|
@@ -1518,7 +1441,7 @@ class TradingOrchestrator:
|
|
|
|
|
|
|
|
|
|
# Get memory usage stats
|
|
|
|
|
try:
|
|
|
|
|
memory_usage = self.model_registry.get_memory_stats() if hasattr(self.model_registry, 'get_memory_stats') else {}
|
|
|
|
|
memory_usage = self.model_manager.get_storage_stats() if hasattr(self.model_manager, 'get_storage_stats') else {}
|
|
|
|
|
except Exception:
|
|
|
|
|
memory_usage = {}
|
|
|
|
|
|
|
|
|
|
@@ -1571,31 +1494,8 @@ class TradingOrchestrator:
|
|
|
|
|
}
|
|
|
|
|
return weights.get(timeframe, 0.5)
|
|
|
|
|
|
|
|
|
|
def update_model_performance(self, model_name: str, was_correct: bool):
|
|
|
|
|
"""Update performance tracking for a model"""
|
|
|
|
|
if model_name in self.model_performance:
|
|
|
|
|
self.model_performance[model_name]['total'] += 1
|
|
|
|
|
if was_correct:
|
|
|
|
|
self.model_performance[model_name]['correct'] += 1
|
|
|
|
|
|
|
|
|
|
# Update accuracy
|
|
|
|
|
total = self.model_performance[model_name]['total']
|
|
|
|
|
correct = self.model_performance[model_name]['correct']
|
|
|
|
|
self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0
|
|
|
|
|
|
|
|
|
|
def adapt_weights(self):
|
|
|
|
|
"""Dynamically adapt model weights based on performance"""
|
|
|
|
|
try:
|
|
|
|
|
for model_name, performance in self.model_performance.items():
|
|
|
|
|
if performance['total'] > 0:
|
|
|
|
|
# Adjust weight based on relative performance
|
|
|
|
|
accuracy = performance['correct'] / performance['total']
|
|
|
|
|
self.model_weights[model_name] = accuracy
|
|
|
|
|
|
|
|
|
|
logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error adapting weights: {e}")
|
|
|
|
|
# Model performance and weight adaptation removed - handled by ModelManager
|
|
|
|
|
# Use self.model_manager for all model performance tracking
|
|
|
|
|
|
|
|
|
|
def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
|
|
|
|
|
"""Get recent decisions for a symbol"""
|
|
|
|
|
@@ -1606,8 +1506,8 @@ class TradingOrchestrator:
|
|
|
|
|
def get_performance_metrics(self) -> Dict[str, Any]:
|
|
|
|
|
"""Get performance metrics for the orchestrator"""
|
|
|
|
|
return {
|
|
|
|
|
'model_performance': self.model_performance.copy(),
|
|
|
|
|
'weights': self.model_weights.copy(),
|
|
|
|
|
# 'model_performance': {}, # Now handled by ModelManager
|
|
|
|
|
# 'weights': {}, # Now handled by ModelManager
|
|
|
|
|
'configuration': {
|
|
|
|
|
'confidence_threshold': self.confidence_threshold,
|
|
|
|
|
'decision_frequency': self.decision_frequency
|
|
|
|
|
@@ -1630,7 +1530,7 @@ class TradingOrchestrator:
|
|
|
|
|
current_time = time.time()
|
|
|
|
|
cache_expiry = 60 # seconds
|
|
|
|
|
|
|
|
|
|
from utils.checkpoint_manager import load_best_checkpoint
|
|
|
|
|
from NN.training.model_manager import load_best_checkpoint
|
|
|
|
|
|
|
|
|
|
# Update each model with REAL checkpoint data (cached)
|
|
|
|
|
# Note: COB_RL removed - functionality integrated into Enhanced CNN
|
|
|
|
|
@@ -1872,7 +1772,7 @@ class TradingOrchestrator:
|
|
|
|
|
'decision_fusion_enabled': self.decision_fusion_enabled,
|
|
|
|
|
'symbols_tracking': len(self.symbols),
|
|
|
|
|
'recent_decisions_count': sum(len(decisions) for decisions in self.recent_decisions.values()),
|
|
|
|
|
'model_weights': self.model_weights.copy(),
|
|
|
|
|
# 'model_weights': {}, # Now handled by ModelManager
|
|
|
|
|
'realtime_processing': self.realtime_processing
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|