fixes
This commit is contained in:
@@ -26,8 +26,6 @@ class ModelsTrainingPanel:
|
||||
data = self._gather_data()
|
||||
|
||||
content: List[html.Div] = []
|
||||
content.append(self._create_header())
|
||||
|
||||
content.append(self._create_models_section(data.get("models", {})))
|
||||
|
||||
if data.get("training_status"):
|
||||
@@ -55,13 +53,21 @@ class ModelsTrainingPanel:
|
||||
if isinstance(stats, dict):
|
||||
stats_summary = stats
|
||||
|
||||
# Model states (for best_loss and checkpoint flags)
|
||||
model_states: Dict[str, Dict[str, Any]] = getattr(self.orchestrator, "model_states", {}) or {}
|
||||
# Model states (for best_loss and checkpoint flags) - use get_model_states() for updated checkpoint info
|
||||
model_states: Dict[str, Dict[str, Any]] = {}
|
||||
if hasattr(self.orchestrator, "get_model_states"):
|
||||
try:
|
||||
model_states = self.orchestrator.get_model_states() or {}
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting model states: {e}")
|
||||
model_states = getattr(self.orchestrator, "model_states", {}) or {}
|
||||
|
||||
# Build models block from stats_summary
|
||||
for model_name, s in stats_summary.items():
|
||||
model_info: Dict[str, Any] = {}
|
||||
try:
|
||||
# Use actual model name - no mapping needed
|
||||
display_name = model_name.upper()
|
||||
# Status: active if we have any inference info
|
||||
total_inf = int(s.get("total_inferences", 0) or 0)
|
||||
status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
|
||||
@@ -84,22 +90,35 @@ class ModelsTrainingPanel:
|
||||
"last_training": s.get("last_training_time"),
|
||||
}
|
||||
|
||||
# Checkpoint flags from orchestrator.model_states if present
|
||||
# Checkpoint flags from orchestrator.model_states - map external names to internal keys
|
||||
ckpt_loaded = False
|
||||
ckpt_failed = False
|
||||
if model_states and model_name in model_states:
|
||||
ckpt_loaded = bool(model_states[model_name].get("checkpoint_loaded", False))
|
||||
ckpt_failed = bool(model_states[model_name].get("checkpoint_failed", False))
|
||||
checkpoint_filename = None
|
||||
|
||||
# Map external model names to internal model state keys
|
||||
state_key_map = {
|
||||
"dqn_agent": "dqn",
|
||||
"enhanced_cnn": "cnn",
|
||||
"extrema_trainer": "extrema_trainer",
|
||||
"decision": "decision",
|
||||
"cob_rl_model": "cob_rl"
|
||||
}
|
||||
state_key = state_key_map.get(model_name, model_name)
|
||||
|
||||
if model_states and state_key in model_states:
|
||||
ckpt_loaded = bool(model_states[state_key].get("checkpoint_loaded", False))
|
||||
ckpt_failed = bool(model_states[state_key].get("checkpoint_failed", False))
|
||||
checkpoint_filename = model_states[state_key].get("checkpoint_filename")
|
||||
|
||||
# If model started fresh, mark as stored after first successful autosave/best_loss update
|
||||
if not ckpt_loaded:
|
||||
cur = model_states[model_name].get("current_loss")
|
||||
best = model_states[model_name].get("best_loss")
|
||||
cur = model_states[state_key].get("current_loss")
|
||||
best = model_states[state_key].get("best_loss")
|
||||
if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
|
||||
ckpt_loaded = True
|
||||
checkpoint_filename = model_states[model_name].get("checkpoint_filename")
|
||||
|
||||
model_info = {
|
||||
"name": model_name,
|
||||
"name": display_name,
|
||||
"status": status,
|
||||
"parameters": 0, # unknown; do not synthesize
|
||||
"last_prediction": {
|
||||
@@ -112,7 +131,7 @@ class ModelsTrainingPanel:
|
||||
"routing_enabled": True,
|
||||
"checkpoint_loaded": ckpt_loaded,
|
||||
"checkpoint_failed": ckpt_failed,
|
||||
"checkpoint_filename": checkpoint_filename if model_states and model_name in model_states else None,
|
||||
"checkpoint_filename": checkpoint_filename,
|
||||
"loss_metrics": loss_metrics,
|
||||
"timing_metrics": timing_metrics,
|
||||
"signal_stats": {},
|
||||
@@ -145,12 +164,6 @@ class ModelsTrainingPanel:
|
||||
logger.error(f"Error gathering panel data: {e}")
|
||||
return result
|
||||
|
||||
def _create_header(self) -> html.Div:
|
||||
return html.Div([
|
||||
html.H6([html.I(className="fas fa-brain me-2 text-primary"), "Models & Training Progress"], className="mb-2"),
|
||||
html.Button([html.I(className="fas fa-sync-alt me-1"), "Refresh"], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2")
|
||||
], className="d-flex justify-content-between align-items-start")
|
||||
|
||||
def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
|
||||
cards = [self._create_model_card(name, info) for name, info in models.items()]
|
||||
return html.Div([
|
||||
|
Reference in New Issue
Block a user