Inference/training toggles UI

This commit is contained in:
Dobromir Popov
2025-07-29 15:51:18 +03:00
parent ff41f0a278
commit ecbbabc0c1
3 changed files with 503 additions and 239 deletions

View File

@ -1433,17 +1433,81 @@ class TradingOrchestrator:
return self.model_toggle_states.get(model_name, {"inference_enabled": True, "training_enabled": True})
def set_model_toggle_state(self, model_name: str, inference_enabled: bool = None, training_enabled: bool = None):
    """Set toggle state for a model - Universal handler for any model.

    Args:
        model_name: Identifier of the model whose toggle state is updated.
        inference_enabled: New inference toggle value; None leaves it unchanged.
        training_enabled: New training toggle value; None leaves it unchanged.
    """
    # FIX: removed the stale duplicate docstring line left behind by a merge —
    # it was a dead string statement immediately after the real docstring.
    # Initialize model toggle state lazily so previously unknown models are
    # handled without requiring prior registration.
    if model_name not in self.model_toggle_states:
        self.model_toggle_states[model_name] = {"inference_enabled": True, "training_enabled": True}
        logger.info(f"Initialized toggle state for new model: {model_name}")
    # Only update the fields the caller explicitly provided (None = keep).
    if inference_enabled is not None:
        self.model_toggle_states[model_name]["inference_enabled"] = inference_enabled
    if training_enabled is not None:
        self.model_toggle_states[model_name]["training_enabled"] = training_enabled
    # Persist so toggles survive a restart.
    self._save_ui_state()
    logger.info(f"Model {model_name} toggle state updated: inference={self.model_toggle_states[model_name]['inference_enabled']}, training={self.model_toggle_states[model_name]['training_enabled']}")
    # Notify any listeners about the toggle change.
    self._notify_model_toggle_change(model_name, self.model_toggle_states[model_name])
def _notify_model_toggle_change(self, model_name: str, toggle_state: Dict[str, bool]):
    """Broadcast a model toggle change to interested components.

    Currently a stub: it only emits a debug log entry, but it is the single
    extension point for wiring in real listeners later.
    """
    try:
        logger.debug(f"Model toggle change notification: {model_name} -> {toggle_state}")
    except Exception as e:
        # Notification is best-effort; never let it propagate.
        logger.debug(f"Error notifying model toggle change: {e}")
def register_model_dynamically(self, model_name: str, model_interface):
    """Register a new model at runtime and set up its toggle state.

    Returns True when the registry accepted the model, False otherwise
    (including when registration raised).
    """
    try:
        # Guard clause: the registry is the source of truth for acceptance.
        if not self.model_registry.register_model(model_interface):
            return False
        # setdefault keeps any pre-existing toggle state untouched.
        self.model_toggle_states.setdefault(
            model_name,
            {
                "inference_enabled": True,
                "training_enabled": True,
            },
        )
        logger.info(f"Registered new model dynamically: {model_name}")
        self._save_ui_state()
        return True
    except Exception as e:
        logger.error(f"Error registering model {model_name} dynamically: {e}")
        return False
def get_all_registered_models(self):
    """Return every known model: registry entries plus toggle-only stubs.

    Models that have a toggle state but are absent from the registry are
    included as placeholder records with registered=False. Returns an
    empty dict on failure.
    """
    try:
        combined = {}
        # Registry models first, when a registry is attached.
        registry = getattr(self, 'model_registry', None)
        if registry:
            combined.update(registry.get_all_models())
        # Fill in models known only through their toggle state.
        for name in self.model_toggle_states:
            combined.setdefault(name, {
                'name': name,
                'type': 'toggle_only',
                'registered': False
            })
        return combined
    except Exception as e:
        logger.error(f"Error getting all registered models: {e}")
        return {}
def is_model_inference_enabled(self, model_name: str) -> bool:
"""Check if model inference is enabled"""
@ -5371,18 +5435,21 @@ class TradingOrchestrator:
if 'model_state_dict' in checkpoint:
self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])
# Update model states
self.model_states["decision"]["initial_loss"] = (
# Update model states - FIX: Use correct key "decision_fusion"
if "decision_fusion" not in self.model_states:
self.model_states["decision_fusion"] = {}
self.model_states["decision_fusion"]["initial_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["current_loss"] = (
self.model_states["decision_fusion"]["current_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["best_loss"] = (
self.model_states["decision_fusion"]["best_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["checkpoint_loaded"] = True
self.model_states["decision"][
self.model_states["decision_fusion"]["checkpoint_loaded"] = True
self.model_states["decision_fusion"][
"checkpoint_filename"
] = metadata.checkpoint_id
@ -5398,8 +5465,15 @@ class TradingOrchestrator:
logger.warning(f"Error loading decision fusion checkpoint: {e}")
logger.info("Decision fusion network starting fresh")
# Initialize optimizer for decision fusion training
self.decision_fusion_optimizer = torch.optim.Adam(
self.decision_fusion_network.parameters(),
lr=decision_fusion_config.get("learning_rate", 0.001)
)
logger.info(f"Decision fusion network initialized on device: {self.device}")
logger.info(f"Decision fusion mode: {self.decision_fusion_mode}")
logger.info(f"Decision fusion optimizer initialized with lr={decision_fusion_config.get('learning_rate', 0.001)}")
except Exception as e:
logger.warning(f"Decision fusion initialization failed: {e}")
@ -5454,6 +5528,64 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error training decision fusion in programmatic mode: {e}")
def _save_decision_fusion_checkpoint(self):
    """Persist the decision fusion network via the checkpoint manager.

    No-op when the network or checkpoint manager is missing. Updates
    self.model_states['decision_fusion'] with checkpoint bookkeeping on
    success; failures are logged, never raised.
    """
    try:
        # Guard: nothing to save without both the network and a manager.
        if not self.decision_fusion_network or not self.checkpoint_manager:
            return
        # Derive a performance score: accuracy if tracked, otherwise a
        # previously stored score, otherwise a neutral 0.5 default.
        perf = 0.5
        stats = self.model_statistics.get('decision_fusion')
        if stats and stats.accuracy is not None:
            perf = stats.accuracy
        elif hasattr(self, 'decision_fusion_performance_score'):
            perf = self.decision_fusion_performance_score
        loss = 1.0 - perf  # checkpoint manager ranks by loss
        payload = {
            'model_state_dict': self.decision_fusion_network.state_dict(),
            'optimizer_state_dict': self.decision_fusion_optimizer.state_dict() if hasattr(self, 'decision_fusion_optimizer') else None,
            'epoch': self.decision_fusion_decisions_count,
            'loss': loss,
            'performance_score': perf,
            'timestamp': datetime.now().isoformat(),
            'model_name': 'decision_fusion',
            'training_data_count': len(self.decision_fusion_training_data)
        }
        saved_path = self.checkpoint_manager.save_model_checkpoint(
            model_name="decision_fusion",
            model_data=payload,
            loss=loss,
            performance_score=perf
        )
        if not saved_path:
            logger.warning("Failed to save decision fusion checkpoint")
            return
        logger.info(f"Decision fusion checkpoint saved: {saved_path}")
        # Record the checkpoint in the model state table.
        state = self.model_states.setdefault('decision_fusion', {})
        state.update({
            'checkpoint_loaded': True,
            'checkpoint_filename': saved_path.name if hasattr(saved_path, 'name') else str(saved_path),
            'current_loss': loss,
            'best_loss': min(state.get('best_loss', float('inf')), loss),
            'last_training': datetime.now(),
            'performance_score': perf
        })
        logger.info(f"Decision fusion model state updated with checkpoint info")
    except Exception as e:
        logger.error(f"Error saving decision fusion checkpoint: {e}")
def _create_decision_fusion_input(
self,
symbol: str,