290 lines
16 KiB
Python
290 lines
16 KiB
Python
"""
Models & Training Progress Panel - Clean Implementation
Displays real-time model status, training metrics, and performance data
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, List, Optional
|
|
from datetime import datetime
|
|
from dash import html, dcc
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelsTrainingPanel:
    """Build the Dash children for the Models & Training Progress panel.

    - Reads ``orchestrator.get_model_statistics_summary()`` plus the model
      states (``get_model_states()`` or the ``model_states`` attribute).
    - No synthetic or placeholder data; sections are empty when the
      orchestrator does not provide the information.
    """

    # Names used in the statistics summary -> keys used in the model states.
    # Names without an entry are looked up unchanged.
    _STATE_KEY_MAP = {
        "dqn_agent": "dqn",
        "enhanced_cnn": "cnn",
        "extrema_trainer": "extrema_trainer",
        "decision": "decision",
        "cob_rl_model": "cob_rl",
    }

    # status value -> (text css class, icon css class, label shown in the card)
    _STATUS_STYLES = {
        "active": ("text-success", "fas fa-check-circle", "ACTIVE"),
        "registered": ("text-warning", "fas fa-circle", "REGISTERED"),
        "inactive": ("text-muted", "fas fa-pause-circle", "INACTIVE"),
        # Emitted by _gather_data when building a model entry failed.
        "error": ("text-danger", "fas fa-exclamation-circle", "ERROR"),
    }
    _UNKNOWN_STATUS_STYLE = ("text-danger", "fas fa-exclamation-circle", "UNKNOWN")

    def __init__(self, orchestrator=None):
        """Store the orchestrator the panel reads from (may be ``None``)."""
        self.orchestrator = orchestrator

    def create_panel(self) -> Any:
        """Return the list of components that make up the panel.

        The models section is always present; the training-status and
        system-metrics sections are appended only when data is available.
        On any failure a single error paragraph is returned instead.
        """
        try:
            data = self._gather_data()

            content: List[Any] = [self._create_models_section(data.get("models", {}))]

            training_status = data.get("training_status")
            if training_status:
                content.append(self._create_training_status_section(training_status))

            system_metrics = data.get("system_metrics")
            if system_metrics:
                content.append(self._create_system_metrics_section(system_metrics))

            # Return children (to be assigned to 'training-metrics' container)
            return content
        except Exception as e:
            logger.error("Error creating models training panel: %s", e, exc_info=True)
            return [
                html.P(f"Error loading training panel: {str(e)}", className="text-danger small")
            ]

    def _gather_data(self) -> Dict[str, Any]:
        """Collect everything the panel renders from the orchestrator.

        Returns a dict with the keys ``models``, ``training_status`` and
        ``system_metrics``. Without an orchestrator all three are empty.
        If collection fails part-way, whatever was gathered so far is returned.
        """
        result: Dict[str, Any] = {"models": {}, "training_status": {}, "system_metrics": {}}
        orchestrator = self.orchestrator
        if not orchestrator:
            return result

        try:
            # Model statistics summary (serializable)
            stats_summary: Dict[str, Dict[str, Any]] = {}
            if hasattr(orchestrator, "get_model_statistics_summary"):
                stats = orchestrator.get_model_statistics_summary()
                if isinstance(stats, dict):
                    stats_summary = stats

            model_states = self._fetch_model_states()

            # One entry per model in the summary; a failure in one model
            # must not hide the others.
            for model_name, model_stats in stats_summary.items():
                try:
                    model_info = self._build_model_info(model_name, model_stats, model_states)
                except Exception as e:
                    logger.debug(f"Error building model info for {model_name}: {e}")
                    model_info = {"name": model_name, "status": "error", "error": str(e)}
                result["models"][model_name] = model_info

            # Enhanced training system status
            training_status: Dict[str, Any] = {}
            try:
                if hasattr(orchestrator, "get_enhanced_training_stats"):
                    training_status = orchestrator.get_enhanced_training_stats() or {}
            except Exception as e:
                logger.debug(f"Enhanced training stats unavailable: {e}")
            result["training_status"] = training_status

            # System metrics (decision fusion, cob integration)
            cob_integration = getattr(orchestrator, "cob_integration", None)
            result["system_metrics"] = {
                "decision_fusion_active": bool(getattr(orchestrator, "decision_fusion_enabled", False)),
                "cob_integration_active": cob_integration is not None,
                "symbols_tracking": len(getattr(cob_integration, "symbols", []) or []),
            }

            return result
        except Exception as e:
            logger.error("Error gathering panel data: %s", e, exc_info=True)
            return result

    def _fetch_model_states(self) -> Dict[str, Any]:
        """Return the orchestrator's model states as a dict (possibly empty).

        ``get_model_states()`` is preferred; the ``model_states`` attribute is
        used when that method is missing or raises. Non-dict values are
        treated as "no states available".
        """
        getter = getattr(self.orchestrator, "get_model_states", None)
        states: Any = None
        fetched = False
        if callable(getter):
            try:
                states = getter() or {}
                fetched = True
            except Exception as e:
                logger.debug(f"Error getting model states: {e}")
        if not fetched:
            states = getattr(self.orchestrator, "model_states", {}) or {}
        return states if isinstance(states, dict) else {}

    def _resolve_checkpoint_state(self, model_name: str, model_states: Dict[str, Any]) -> tuple:
        """Return ``(checkpoint_loaded, checkpoint_failed, checkpoint_filename)``.

        A missing or malformed state yields ``(False, False, None)`` rather
        than an error, so the model's statistics are still displayed.
        """
        state_key = self._STATE_KEY_MAP.get(model_name, model_name)
        state = model_states.get(state_key) if model_states else None
        if not isinstance(state, dict):
            return False, False, None

        loaded = bool(state.get("checkpoint_loaded", False))
        failed = bool(state.get("checkpoint_failed", False))
        filename = state.get("checkpoint_filename")

        # If model started fresh, mark as stored after first successful
        # autosave/best_loss update (current loss has reached the best loss).
        if not loaded:
            current = state.get("current_loss")
            best = state.get("best_loss")
            if isinstance(current, (int, float)) and isinstance(best, (int, float)) and current <= best:
                loaded = True

        return loaded, failed, filename

    def _build_model_info(self, model_name: str, stats: Dict[str, Any], model_states: Dict[str, Any]) -> Dict[str, Any]:
        """Translate one statistics-summary entry into the dict a card renders.

        Raises whatever the malformed input raises (e.g. ``stats`` not being a
        dict); the caller converts that into an ``"error"`` entry.
        """
        # Use actual model name - no mapping needed
        display_name = model_name.upper()

        # Status: active if we have any inference info
        total_inferences = int(stats.get("total_inferences", 0) or 0)
        last_inference_time = stats.get("last_inference_time")
        status = "active" if (total_inferences > 0 or last_inference_time) else "registered"

        last_action = stats.get("last_prediction")
        ckpt_loaded, ckpt_failed, checkpoint_filename = self._resolve_checkpoint_state(model_name, model_states)

        return {
            "name": display_name,
            "status": status,
            "parameters": 0,  # unknown; do not synthesize
            "last_prediction": {
                "action": last_action if last_action is not None else "",
                "confidence": stats.get("last_confidence"),
                "timestamp": last_inference_time if last_inference_time else "",
            },
            "training_enabled": True,
            "inference_enabled": True,
            "routing_enabled": True,
            "checkpoint_loaded": ckpt_loaded,
            "checkpoint_failed": ckpt_failed,
            "checkpoint_filename": checkpoint_filename,
            "loss_metrics": {
                "current_loss": stats.get("current_loss"),
                "best_loss": stats.get("best_loss"),
            },
            "timing_metrics": {
                "inferences_per_second": stats.get("inference_rate_per_second"),
                "last_inference": last_inference_time,
                "last_training": stats.get("last_training_time"),
            },
            "signal_stats": {},
        }

    def _create_models_section(self, models: Dict[str, Any]) -> "html.Div":
        """Return the "Loaded Models" header followed by one card per model."""
        cards = [self._create_model_card(name, info) for name, info in models.items()]
        return html.Div([
            html.H6(
                [html.I(className="fas fa-microchip me-2 text-success"), f"Loaded Models ({len(models)})"],
                className="mb-2",
            ),
            html.Div(cards),
        ])

    def _create_toggle(self, model_name: str, label: str, toggle_type: str, enabled: Any,
                       checklist_class: str, wrapper_class: str) -> "html.Div":
        """Return one labelled checkbox with a dict id keyed by model and toggle type."""
        return html.Div([
            html.Label(label, className="text-muted small me-1", style={"fontSize": "10px"}),
            dcc.Checklist(
                id={'type': 'model-toggle', 'model': model_name, 'toggle_type': toggle_type},
                options=[{"label": "", "value": True}],
                value=[True] if enabled else [],
                className=checklist_class,
                style={"transform": "scale(0.7)"},
            ),
        ], className=wrapper_class)

    def _create_model_card(self, model_name: str, model_data: Dict[str, Any]) -> "html.Div":
        """Return the bordered card for a single model.

        Shows status, checkpoint badge, the inference/training/routing
        toggles, last prediction, loss, inference rate and last activity.
        Missing values are rendered as empty text or "N/A", never as "None".
        """
        status = model_data.get("status", "unknown")
        status_class, status_icon, status_text = self._STATUS_STYLES.get(status, self._UNKNOWN_STATUS_STYLE)

        last_pred = model_data.get("last_prediction") or {}
        pred_action = last_pred.get("action") or "NONE"
        pred_confidence = last_pred.get("confidence")
        pred_time = last_pred.get("timestamp") or "N/A"
        if pred_action == "BUY":
            action_class = "text-success"
        elif pred_action == "SELL":
            action_class = "text-danger"
        else:
            action_class = "text-warning"

        loss_metrics = model_data.get("loss_metrics") or {}
        current_loss = loss_metrics.get("current_loss")
        best_loss = loss_metrics.get("best_loss")
        loss_is_number = isinstance(current_loss, (int, float))
        if current_loss is None:
            loss_class = "text-muted"
        elif loss_is_number and current_loss < 0.1:
            loss_class = "text-success"
        elif loss_is_number and current_loss < 0.5:
            loss_class = "text-warning"
        else:
            loss_class = "text-danger"

        timing = model_data.get("timing_metrics") or {}
        rate = timing.get("inferences_per_second")
        # The keys exist even when no inference/training happened yet (value
        # None), so a .get() default would never apply; fall back explicitly.
        last_inference = timing.get("last_inference") or "N/A"
        last_training = timing.get("last_training") or "N/A"

        # Tooltip title showing checkpoint info and build error (if any)
        title_parts: List[str] = [f"{status_text}"]
        ckpt_name = model_data.get('checkpoint_filename')
        if ckpt_name:
            title_parts.append(f"CKPT: {ckpt_name}")
        error_text = model_data.get("error")
        if error_text:
            title_parts.append(f"Error: {error_text}")
        header_title = " | ".join(title_parts)

        if model_data.get("checkpoint_loaded"):
            ckpt_label, ckpt_class = " [CKPT]", "text-success"
        elif model_data.get("checkpoint_failed"):
            ckpt_label, ckpt_class = " [FAILED]", "text-danger"
        else:
            ckpt_label, ckpt_class = " [FRESH]", "text-warning"

        best_loss_spans = []
        if isinstance(best_loss, (int, float)):
            best_loss_spans = [
                html.Span(" | Best: ", className="text-muted small"),
                html.Span(f"{best_loss:.4f}", className="text-success small"),
            ]

        header = html.Div([
            html.Div([
                html.I(className=f"{status_icon} me-2 {status_class}"),
                html.Strong(f"{model_name.upper()}", className=status_class, title=header_title),
                html.Span(f" - {status_text}", className=f"{status_class} small ms-1"),
                html.Span(ckpt_label, className=f"small {ckpt_class} ms-1"),
            ], style={"flex": "1"}),
            html.Div([
                self._create_toggle(
                    model_name, "Inf", "inference", model_data.get('inference_enabled', True),
                    "form-check-input me-2", "d-flex align-items-center me-2",
                ),
                self._create_toggle(
                    model_name, "Trn", "training", model_data.get('training_enabled', True),
                    "form-check-input", "d-flex align-items-center me-2",
                ),
                self._create_toggle(
                    model_name, "Route", "routing", model_data.get('routing_enabled', True),
                    "form-check-input", "d-flex align-items-center",
                ),
            ], className="d-flex"),
        ], className="d-flex align-items-center mb-2")

        details = html.Div([
            html.Div([
                html.Span("Last: ", className="text-muted small"),
                html.Span(f"{pred_action}", className=f"small fw-bold {action_class}"),
                # NOTE(review): confidence is printed as-is with a "%" suffix;
                # confirm the orchestrator reports it on a 0-100 scale.
                html.Span(
                    f" ({pred_confidence:.1f}%)" if isinstance(pred_confidence, (int, float)) else "",
                    className="text-muted small",
                ),
                html.Span(f" @ {pred_time}", className="text-muted small"),
            ], className="mb-1"),
            html.Div([
                html.Span("Loss: ", className="text-muted small"),
                html.Span(f"{current_loss:.4f}" if loss_is_number else "", className=f"small fw-bold {loss_class}"),
                *best_loss_spans,
            ], className="mb-1"),
            html.Div([
                html.Span("Rate: ", className="text-muted small"),
                html.Span(f"{rate:.2f}/s" if isinstance(rate, (int, float)) else "", className="text-info small"),
            ], className="mb-1"),
            html.Div([
                html.Span("Last Inf: ", className="text-muted small"),
                html.Span(f"{last_inference}", className="text-info small"),
                html.Span(" | Train: ", className="text-muted small"),
                html.Span(f"{last_training}", className="text-warning small"),
            ], className="mb-1"),
        ])

        return html.Div([header, details], className="border rounded p-2 mb-2")

    def _create_training_status_section(self, status: Dict[str, Any]) -> "html.Div":
        """Return the "Training Status" section.

        An empty/``None`` status or one with ``status == "error"`` renders an
        error block; otherwise the enabled flag is shown as ON/OFF.
        """
        status = status or {}  # tolerate None so the error branch can still read it
        if not status or status.get("status") == "error":
            return html.Div([
                html.Hr(),
                html.H6(
                    [html.I(className="fas fa-exclamation-triangle me-2 text-danger"), "Training Status Error"],
                    className="mb-2",
                ),
                html.P(f"Error: {status.get('error', 'Unknown')}", className="text-danger small"),
            ])

        is_enabled = status.get("training_enabled", False)
        return html.Div([
            html.Hr(),
            html.H6([html.I(className="fas fa-brain me-2 text-secondary"), "Training Status"], className="mb-2"),
            html.Div([
                html.Span("Enabled: ", className="text-muted small"),
                html.Span(
                    "ON" if is_enabled else "OFF",
                    className=f"small fw-bold {'text-success' if is_enabled else 'text-warning'}",
                ),
            ], className="mb-1"),
        ])

    def _create_system_metrics_section(self, metrics: Dict[str, Any]) -> "html.Div":
        """Return the "System Performance" section (fusion, COB, tracked symbols)."""
        metrics = metrics or {}
        fusion_on = metrics.get("decision_fusion_active")
        cob_on = metrics.get("cob_integration_active")
        return html.Div([
            html.Hr(),
            html.H6([html.I(className="fas fa-chart-line me-2 text-primary"), "System Performance"], className="mb-2"),
            html.Div([
                html.Span("Decision Fusion: ", className="text-muted small"),
                html.Span(
                    "ON" if fusion_on else "OFF",
                    className=f"small {'text-success' if fusion_on else 'text-muted'}",
                ),
                html.Span(" | COB: ", className="text-muted small"),
                html.Span(
                    "ON" if cob_on else "OFF",
                    className=f"small {'text-success' if cob_on else 'text-muted'}",
                ),
            ], className="mb-1"),
            html.Div([
                html.Span("Tracking: ", className="text-muted small"),
                html.Span(f"{metrics.get('symbols_tracking', 0)} symbols", className="text-info small"),
            ], className="mb-0"),
        ])