ui tweaks
This commit is contained in:
@@ -632,31 +632,63 @@ class AnnotationDashboard:
|
||||
self.training_adapter.orchestrator = self.orchestrator
|
||||
logger.info("TradingOrchestrator initialized")
|
||||
|
||||
# Get checkpoint info before loading
|
||||
checkpoint_info = self._get_best_checkpoint_info(model_name)
|
||||
|
||||
# Load the specific model
|
||||
if model_name == 'Transformer':
|
||||
logger.info("Loading Transformer model...")
|
||||
self.orchestrator.load_transformer_model()
|
||||
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer_trainer
|
||||
|
||||
# Store checkpoint info in orchestrator for UI access
|
||||
if checkpoint_info:
|
||||
self.orchestrator.transformer_checkpoint_info = {
|
||||
'status': 'loaded',
|
||||
'filename': checkpoint_info.get('filename', 'unknown'),
|
||||
'epoch': checkpoint_info.get('epoch', 0),
|
||||
'loss': checkpoint_info.get('loss', 0.0),
|
||||
'accuracy': checkpoint_info.get('accuracy', 0.0),
|
||||
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
logger.info("Transformer model loaded successfully")
|
||||
|
||||
elif model_name == 'CNN':
|
||||
logger.info("Loading CNN model...")
|
||||
self.orchestrator.load_cnn_model()
|
||||
self.loaded_models['CNN'] = self.orchestrator.cnn_model
|
||||
|
||||
# Store checkpoint info
|
||||
if checkpoint_info:
|
||||
self.orchestrator.cnn_checkpoint_info = {
|
||||
'status': 'loaded',
|
||||
'filename': checkpoint_info.get('filename', 'unknown'),
|
||||
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
logger.info("CNN model loaded successfully")
|
||||
|
||||
elif model_name == 'DQN':
|
||||
logger.info("Loading DQN model...")
|
||||
self.orchestrator.load_dqn_model()
|
||||
self.loaded_models['DQN'] = self.orchestrator.dqn_agent
|
||||
|
||||
# Store checkpoint info
|
||||
if checkpoint_info:
|
||||
self.orchestrator.dqn_checkpoint_info = {
|
||||
'status': 'loaded',
|
||||
'filename': checkpoint_info.get('filename', 'unknown'),
|
||||
'loaded_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
logger.info("DQN model loaded successfully")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown model name: {model_name}")
|
||||
return
|
||||
|
||||
# Get checkpoint info for display
|
||||
checkpoint_info = self._get_best_checkpoint_info(model_name)
|
||||
# Log checkpoint info
|
||||
if checkpoint_info:
|
||||
logger.info(f" Checkpoint: {checkpoint_info.get('filename', 'N/A')}")
|
||||
if checkpoint_info.get('accuracy'):
|
||||
@@ -2013,23 +2045,27 @@ class AnnotationDashboard:
|
||||
logger.warning(f"self.available_models is not a list: {type(self.available_models)}. Resetting to default.")
|
||||
self.available_models = ['Transformer', 'COB_RL', 'CNN', 'DQN']
|
||||
|
||||
# Ensure self.loaded_models is a list/set
|
||||
# Ensure self.loaded_models exists (it's a dict)
|
||||
if not hasattr(self, 'loaded_models'):
|
||||
self.loaded_models = []
|
||||
self.loaded_models = {}
|
||||
|
||||
# Build model state dict with checkpoint info
|
||||
logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}")
|
||||
logger.info(f"Currently loaded models: {list(self.loaded_models.keys())}")
|
||||
model_states = []
|
||||
for model_name in self.available_models:
|
||||
is_loaded = model_name in self.loaded_models
|
||||
# Check if model is in loaded_models dict
|
||||
is_loaded = model_name in self.loaded_models and self.loaded_models[model_name] is not None
|
||||
|
||||
# Get checkpoint info (even for unloaded models)
|
||||
checkpoint_info = None
|
||||
|
||||
# If loaded, get from orchestrator
|
||||
if is_loaded and self.orchestrator:
|
||||
if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
|
||||
cp_info = self.orchestrator.transformer_checkpoint_info
|
||||
checkpoint_attr = f"{model_name.lower()}_checkpoint_info"
|
||||
|
||||
if hasattr(self.orchestrator, checkpoint_attr):
|
||||
cp_info = getattr(self.orchestrator, checkpoint_attr)
|
||||
if cp_info and cp_info.get('status') == 'loaded':
|
||||
checkpoint_info = {
|
||||
'filename': cp_info.get('filename', 'unknown'),
|
||||
|
||||
Reference in New Issue
Block a user