This commit is contained in:
Dobromir Popov
2025-08-10 21:28:49 +03:00
parent 7289366a35
commit 8738f02d24
5 changed files with 311 additions and 78 deletions

View File

@@ -26,8 +26,6 @@ class ModelsTrainingPanel:
data = self._gather_data()
content: List[html.Div] = []
content.append(self._create_header())
content.append(self._create_models_section(data.get("models", {})))
if data.get("training_status"):
@@ -55,13 +53,21 @@ class ModelsTrainingPanel:
if isinstance(stats, dict):
stats_summary = stats
# Model states (for best_loss and checkpoint flags)
model_states: Dict[str, Dict[str, Any]] = getattr(self.orchestrator, "model_states", {}) or {}
# Model states (for best_loss and checkpoint flags) - use get_model_states() for updated checkpoint info
model_states: Dict[str, Dict[str, Any]] = {}
if hasattr(self.orchestrator, "get_model_states"):
try:
model_states = self.orchestrator.get_model_states() or {}
except Exception as e:
logger.debug(f"Error getting model states: {e}")
model_states = getattr(self.orchestrator, "model_states", {}) or {}
# Build models block from stats_summary
for model_name, s in stats_summary.items():
model_info: Dict[str, Any] = {}
try:
# Use actual model name - no mapping needed
display_name = model_name.upper()
# Status: active if we have any inference info
total_inf = int(s.get("total_inferences", 0) or 0)
status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
@@ -84,22 +90,35 @@ class ModelsTrainingPanel:
"last_training": s.get("last_training_time"),
}
# Checkpoint flags from orchestrator.model_states if present
# Checkpoint flags from orchestrator.model_states - map external names to internal keys
ckpt_loaded = False
ckpt_failed = False
if model_states and model_name in model_states:
ckpt_loaded = bool(model_states[model_name].get("checkpoint_loaded", False))
ckpt_failed = bool(model_states[model_name].get("checkpoint_failed", False))
checkpoint_filename = None
# Map external model names to internal model state keys
state_key_map = {
"dqn_agent": "dqn",
"enhanced_cnn": "cnn",
"extrema_trainer": "extrema_trainer",
"decision": "decision",
"cob_rl_model": "cob_rl"
}
state_key = state_key_map.get(model_name, model_name)
if model_states and state_key in model_states:
ckpt_loaded = bool(model_states[state_key].get("checkpoint_loaded", False))
ckpt_failed = bool(model_states[state_key].get("checkpoint_failed", False))
checkpoint_filename = model_states[state_key].get("checkpoint_filename")
# If model started fresh, mark as stored after first successful autosave/best_loss update
if not ckpt_loaded:
cur = model_states[model_name].get("current_loss")
best = model_states[model_name].get("best_loss")
cur = model_states[state_key].get("current_loss")
best = model_states[state_key].get("best_loss")
if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
ckpt_loaded = True
checkpoint_filename = model_states[model_name].get("checkpoint_filename")
model_info = {
"name": model_name,
"name": display_name,
"status": status,
"parameters": 0, # unknown; do not synthesize
"last_prediction": {
@@ -112,7 +131,7 @@ class ModelsTrainingPanel:
"routing_enabled": True,
"checkpoint_loaded": ckpt_loaded,
"checkpoint_failed": ckpt_failed,
"checkpoint_filename": checkpoint_filename if model_states and model_name in model_states else None,
"checkpoint_filename": checkpoint_filename,
"loss_metrics": loss_metrics,
"timing_metrics": timing_metrics,
"signal_stats": {},
@@ -145,12 +164,6 @@ class ModelsTrainingPanel:
logger.error(f"Error gathering panel data: {e}")
return result
def _create_header(self) -> html.Div:
return html.Div([
html.H6([html.I(className="fas fa-brain me-2 text-primary"), "Models & Training Progress"], className="mb-2"),
html.Button([html.I(className="fas fa-sync-alt me-1"), "Refresh"], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2")
], className="d-flex justify-content-between align-items-start")
def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
cards = [self._create_model_card(name, info) for name, info in models.items()]
return html.Div([