inf/trn toggles UI
This commit is contained in:
@ -1433,17 +1433,81 @@ class TradingOrchestrator:
|
||||
return self.model_toggle_states.get(model_name, {"inference_enabled": True, "training_enabled": True})
|
||||
|
||||
def set_model_toggle_state(self, model_name: str, inference_enabled: bool = None, training_enabled: bool = None):
|
||||
"""Set toggle state for a model"""
|
||||
"""Set toggle state for a model - Universal handler for any model"""
|
||||
# Initialize model toggle state if it doesn't exist
|
||||
if model_name not in self.model_toggle_states:
|
||||
self.model_toggle_states[model_name] = {"inference_enabled": True, "training_enabled": True}
|
||||
logger.info(f"Initialized toggle state for new model: {model_name}")
|
||||
|
||||
# Update the toggle states
|
||||
if inference_enabled is not None:
|
||||
self.model_toggle_states[model_name]["inference_enabled"] = inference_enabled
|
||||
if training_enabled is not None:
|
||||
self.model_toggle_states[model_name]["training_enabled"] = training_enabled
|
||||
|
||||
# Save the updated state
|
||||
self._save_ui_state()
|
||||
|
||||
# Log the change
|
||||
logger.info(f"Model {model_name} toggle state updated: inference={self.model_toggle_states[model_name]['inference_enabled']}, training={self.model_toggle_states[model_name]['training_enabled']}")
|
||||
|
||||
# Notify any listeners about the toggle change
|
||||
self._notify_model_toggle_change(model_name, self.model_toggle_states[model_name])
|
||||
|
||||
def _notify_model_toggle_change(self, model_name: str, toggle_state: Dict[str, bool]):
|
||||
"""Notify components about model toggle changes"""
|
||||
try:
|
||||
# This can be extended to notify other components
|
||||
# For now, just log the change
|
||||
logger.debug(f"Model toggle change notification: {model_name} -> {toggle_state}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error notifying model toggle change: {e}")
|
||||
|
||||
def register_model_dynamically(self, model_name: str, model_interface):
|
||||
"""Register a new model dynamically and set up its toggle state"""
|
||||
try:
|
||||
# Register with model registry
|
||||
if self.model_registry.register_model(model_interface):
|
||||
# Initialize toggle state for the new model
|
||||
if model_name not in self.model_toggle_states:
|
||||
self.model_toggle_states[model_name] = {
|
||||
"inference_enabled": True,
|
||||
"training_enabled": True
|
||||
}
|
||||
logger.info(f"Registered new model dynamically: {model_name}")
|
||||
self._save_ui_state()
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering model {model_name} dynamically: {e}")
|
||||
return False
|
||||
|
||||
def get_all_registered_models(self):
|
||||
"""Get all registered models from registry and toggle states"""
|
||||
try:
|
||||
all_models = {}
|
||||
|
||||
# Get models from registry
|
||||
if hasattr(self, 'model_registry') and self.model_registry:
|
||||
registry_models = self.model_registry.get_all_models()
|
||||
all_models.update(registry_models)
|
||||
|
||||
# Add any models that have toggle states but aren't in registry
|
||||
for model_name in self.model_toggle_states.keys():
|
||||
if model_name not in all_models:
|
||||
all_models[model_name] = {
|
||||
'name': model_name,
|
||||
'type': 'toggle_only',
|
||||
'registered': False
|
||||
}
|
||||
|
||||
return all_models
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all registered models: {e}")
|
||||
return {}
|
||||
|
||||
def is_model_inference_enabled(self, model_name: str) -> bool:
|
||||
"""Check if model inference is enabled"""
|
||||
@ -5371,18 +5435,21 @@ class TradingOrchestrator:
|
||||
if 'model_state_dict' in checkpoint:
|
||||
self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
# Update model states
|
||||
self.model_states["decision"]["initial_loss"] = (
|
||||
# Update model states - FIX: Use correct key "decision_fusion"
|
||||
if "decision_fusion" not in self.model_states:
|
||||
self.model_states["decision_fusion"] = {}
|
||||
|
||||
self.model_states["decision_fusion"]["initial_loss"] = (
|
||||
metadata.performance_metrics.get("loss", 0.0)
|
||||
)
|
||||
self.model_states["decision"]["current_loss"] = (
|
||||
self.model_states["decision_fusion"]["current_loss"] = (
|
||||
metadata.performance_metrics.get("loss", 0.0)
|
||||
)
|
||||
self.model_states["decision"]["best_loss"] = (
|
||||
self.model_states["decision_fusion"]["best_loss"] = (
|
||||
metadata.performance_metrics.get("loss", 0.0)
|
||||
)
|
||||
self.model_states["decision"]["checkpoint_loaded"] = True
|
||||
self.model_states["decision"][
|
||||
self.model_states["decision_fusion"]["checkpoint_loaded"] = True
|
||||
self.model_states["decision_fusion"][
|
||||
"checkpoint_filename"
|
||||
] = metadata.checkpoint_id
|
||||
|
||||
@ -5398,8 +5465,15 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Error loading decision fusion checkpoint: {e}")
|
||||
logger.info("Decision fusion network starting fresh")
|
||||
|
||||
# Initialize optimizer for decision fusion training
|
||||
self.decision_fusion_optimizer = torch.optim.Adam(
|
||||
self.decision_fusion_network.parameters(),
|
||||
lr=decision_fusion_config.get("learning_rate", 0.001)
|
||||
)
|
||||
|
||||
logger.info(f"Decision fusion network initialized on device: {self.device}")
|
||||
logger.info(f"Decision fusion mode: {self.decision_fusion_mode}")
|
||||
logger.info(f"Decision fusion optimizer initialized with lr={decision_fusion_config.get('learning_rate', 0.001)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Decision fusion initialization failed: {e}")
|
||||
@ -5454,6 +5528,64 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error training decision fusion in programmatic mode: {e}")
|
||||
|
||||
def _save_decision_fusion_checkpoint(self):
|
||||
"""Save decision fusion model checkpoint"""
|
||||
try:
|
||||
if not self.decision_fusion_network or not self.checkpoint_manager:
|
||||
return
|
||||
|
||||
# Get current performance score
|
||||
model_stats = self.model_statistics.get('decision_fusion')
|
||||
performance_score = 0.5 # Default score
|
||||
|
||||
if model_stats and model_stats.accuracy is not None:
|
||||
performance_score = model_stats.accuracy
|
||||
elif hasattr(self, 'decision_fusion_performance_score'):
|
||||
performance_score = self.decision_fusion_performance_score
|
||||
|
||||
# Create checkpoint data
|
||||
checkpoint_data = {
|
||||
'model_state_dict': self.decision_fusion_network.state_dict(),
|
||||
'optimizer_state_dict': self.decision_fusion_optimizer.state_dict() if hasattr(self, 'decision_fusion_optimizer') else None,
|
||||
'epoch': self.decision_fusion_decisions_count,
|
||||
'loss': 1.0 - performance_score, # Convert performance to loss
|
||||
'performance_score': performance_score,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'model_name': 'decision_fusion',
|
||||
'training_data_count': len(self.decision_fusion_training_data)
|
||||
}
|
||||
|
||||
# Save checkpoint using checkpoint manager
|
||||
checkpoint_path = self.checkpoint_manager.save_model_checkpoint(
|
||||
model_name="decision_fusion",
|
||||
model_data=checkpoint_data,
|
||||
loss=1.0 - performance_score,
|
||||
performance_score=performance_score
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
logger.info(f"Decision fusion checkpoint saved: {checkpoint_path}")
|
||||
|
||||
# Update model state
|
||||
if 'decision_fusion' not in self.model_states:
|
||||
self.model_states['decision_fusion'] = {}
|
||||
|
||||
self.model_states['decision_fusion'].update({
|
||||
'checkpoint_loaded': True,
|
||||
'checkpoint_filename': checkpoint_path.name if hasattr(checkpoint_path, 'name') else str(checkpoint_path),
|
||||
'current_loss': 1.0 - performance_score,
|
||||
'best_loss': min(self.model_states['decision_fusion'].get('best_loss', float('inf')), 1.0 - performance_score),
|
||||
'last_training': datetime.now(),
|
||||
'performance_score': performance_score
|
||||
})
|
||||
|
||||
logger.info(f"Decision fusion model state updated with checkpoint info")
|
||||
else:
|
||||
logger.warning("Failed to save decision fusion checkpoint")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving decision fusion checkpoint: {e}")
|
||||
|
||||
def _create_decision_fusion_input(
|
||||
self,
|
||||
symbol: str,
|
||||
|
Reference in New Issue
Block a user