fixes
This commit is contained in:
32
.ckpt_count.py
Normal file
32
.ckpt_count.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import sys, json, os, traceback
|
||||||
|
sys.path.insert(0, r'F:\projects\gogo2')
|
||||||
|
res={}
|
||||||
|
try:
|
||||||
|
from utils.database_manager import get_database_manager
|
||||||
|
db=get_database_manager()
|
||||||
|
def db_count(name):
|
||||||
|
try:
|
||||||
|
lst = db.list_checkpoints(name)
|
||||||
|
return len(lst) if lst is not None else 0
|
||||||
|
except Exception as e:
|
||||||
|
print("DB error for %s: %s" % (name, str(e)))
|
||||||
|
return -1
|
||||||
|
res.setdefault('db', {})['dqn_agent']=db_count('dqn_agent')
|
||||||
|
res['db']['enhanced_cnn']=db_count('enhanced_cnn')
|
||||||
|
except Exception as e:
|
||||||
|
res['db']={'error': str(e)}
|
||||||
|
try:
|
||||||
|
from utils.checkpoint_manager import get_checkpoint_manager
|
||||||
|
cm=get_checkpoint_manager()
|
||||||
|
def fs_count(name):
|
||||||
|
try:
|
||||||
|
lst = cm.get_all_checkpoints(name)
|
||||||
|
return len(lst) if lst is not None else 0
|
||||||
|
except Exception as e:
|
||||||
|
print("FS error for %s: %s" % (name, str(e)))
|
||||||
|
return -1
|
||||||
|
res.setdefault('fs', {})['dqn_agent']=fs_count('dqn_agent')
|
||||||
|
res['fs']['enhanced_cnn']=fs_count('enhanced_cnn')
|
||||||
|
except Exception as e:
|
||||||
|
res['fs']={'error': str(e)}
|
||||||
|
print(json.dumps(res))
|
@@ -315,6 +315,16 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
logger.info(f"Using device: {self.device}")
|
logger.info(f"Using device: {self.device}")
|
||||||
|
|
||||||
|
# Canonical model name aliases to eliminate ambiguity across UI/DB/FS
|
||||||
|
# Canonical → accepted aliases (internal/legacy)
|
||||||
|
self.model_name_aliases: Dict[str, list] = {
|
||||||
|
"DQN": ["dqn_agent", "dqn"],
|
||||||
|
"CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
|
||||||
|
"EXTREMA": ["extrema_trainer", "extrema"],
|
||||||
|
"COB": ["cob_rl_model", "cob_rl"],
|
||||||
|
"DECISION": ["decision_fusion", "decision"],
|
||||||
|
}
|
||||||
|
|
||||||
# Configuration - AGGRESSIVE for more training data
|
# Configuration - AGGRESSIVE for more training data
|
||||||
self.confidence_threshold = self.config.orchestrator.get(
|
self.confidence_threshold = self.config.orchestrator.get(
|
||||||
"confidence_threshold", 0.15
|
"confidence_threshold", 0.15
|
||||||
@@ -555,17 +565,16 @@ class TradingOrchestrator:
|
|||||||
def _normalize_model_name(self, name: str) -> str:
|
def _normalize_model_name(self, name: str) -> str:
|
||||||
"""Map various registry/UI names to canonical toggle keys."""
|
"""Map various registry/UI names to canonical toggle keys."""
|
||||||
try:
|
try:
|
||||||
mapping = {
|
# Use alias map to unify names to canonical keys
|
||||||
"dqn_agent": "dqn",
|
alias_to_canonical = {
|
||||||
"enhanced_cnn": "cnn",
|
**{alias: "DQN" for alias in ["dqn_agent", "dqn"]},
|
||||||
"cnn_model": "cnn",
|
**{alias: "CNN" for alias in ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"]},
|
||||||
"decision": "decision_fusion",
|
**{alias: "EXTREMA" for alias in ["extrema_trainer", "extrema"]},
|
||||||
"decision_fusion": "decision_fusion",
|
**{alias: "COB" for alias in ["cob_rl_model", "cob_rl"]},
|
||||||
"cob_rl_model": "cob_rl",
|
**{alias: "DECISION" for alias in ["decision_fusion", "decision"]},
|
||||||
"cob_rl": "cob_rl",
|
"transformer_model": "TRANSFORMER",
|
||||||
"transformer_model": "transformer",
|
|
||||||
}
|
}
|
||||||
return mapping.get(name, name)
|
return alias_to_canonical.get(name, name)
|
||||||
except Exception:
|
except Exception:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
@@ -643,41 +652,44 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
||||||
|
|
||||||
# Load best checkpoint and capture initial state (using database metadata)
|
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
|
||||||
checkpoint_loaded = False
|
checkpoint_loaded = False
|
||||||
if hasattr(self.rl_agent, "load_best_checkpoint"):
|
if hasattr(self.rl_agent, "load_best_checkpoint"):
|
||||||
try:
|
try:
|
||||||
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
self.rl_agent.load_best_checkpoint() # Load model state if available
|
||||||
# Check if we have checkpoints available using database metadata (fast!)
|
# 1) Try DB metadata first
|
||||||
db_manager = get_database_manager()
|
try:
|
||||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
|
db_manager = get_database_manager()
|
||||||
"dqn_agent"
|
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
|
||||||
)
|
except Exception:
|
||||||
|
checkpoint_metadata = None
|
||||||
if checkpoint_metadata:
|
if checkpoint_metadata:
|
||||||
self.model_states["dqn"]["initial_loss"] = 0.412
|
self.model_states["dqn"]["initial_loss"] = 0.412
|
||||||
self.model_states["dqn"]["current_loss"] = (
|
self.model_states["dqn"]["current_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
|
||||||
checkpoint_metadata.performance_metrics.get("loss", 0.0)
|
self.model_states["dqn"]["best_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
|
||||||
)
|
|
||||||
self.model_states["dqn"]["best_loss"] = (
|
|
||||||
checkpoint_metadata.performance_metrics.get("loss", 0.0)
|
|
||||||
)
|
|
||||||
self.model_states["dqn"]["checkpoint_loaded"] = True
|
self.model_states["dqn"]["checkpoint_loaded"] = True
|
||||||
self.model_states["dqn"][
|
self.model_states["dqn"]["checkpoint_filename"] = checkpoint_metadata.checkpoint_id
|
||||||
"checkpoint_filename"
|
|
||||||
] = checkpoint_metadata.checkpoint_id
|
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||||
logger.info(
|
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||||
f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
|
else:
|
||||||
)
|
# 2) Filesystem fallback via CheckpointManager
|
||||||
|
try:
|
||||||
|
from utils.checkpoint_manager import get_checkpoint_manager
|
||||||
|
cm = get_checkpoint_manager()
|
||||||
|
result = cm.load_best_checkpoint("dqn_agent")
|
||||||
|
if result:
|
||||||
|
model_path, meta = result
|
||||||
|
# We already loaded model weights via load_best_checkpoint; just record metadata
|
||||||
|
self.model_states["dqn"]["checkpoint_loaded"] = True
|
||||||
|
self.model_states["dqn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
|
||||||
|
checkpoint_loaded = True
|
||||||
|
logger.info(f"DQN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
|
||||||
f"Error loading DQN checkpoint (likely dimension mismatch): {e}"
|
logger.info("DQN will start fresh due to checkpoint incompatibility")
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"DQN will start fresh due to checkpoint incompatibility"
|
|
||||||
)
|
|
||||||
# Reset the agent to handle dimension mismatch
|
|
||||||
checkpoint_loaded = False
|
checkpoint_loaded = False
|
||||||
|
|
||||||
if not checkpoint_loaded:
|
if not checkpoint_loaded:
|
||||||
@@ -712,7 +724,7 @@ class TradingOrchestrator:
|
|||||||
self.cnn_model.parameters(), lr=0.001
|
self.cnn_model.parameters(), lr=0.001
|
||||||
) # Initialize optimizer for CNN
|
) # Initialize optimizer for CNN
|
||||||
|
|
||||||
# Load best checkpoint and capture initial state (using database metadata)
|
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
|
||||||
checkpoint_loaded = False
|
checkpoint_loaded = False
|
||||||
try:
|
try:
|
||||||
db_manager = get_database_manager()
|
db_manager = get_database_manager()
|
||||||
@@ -738,6 +750,19 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||||
|
# Filesystem fallback
|
||||||
|
try:
|
||||||
|
from utils.checkpoint_manager import get_checkpoint_manager
|
||||||
|
cm = get_checkpoint_manager()
|
||||||
|
result = cm.load_best_checkpoint("enhanced_cnn")
|
||||||
|
if result:
|
||||||
|
model_path, meta = result
|
||||||
|
self.model_states["cnn"]["checkpoint_loaded"] = True
|
||||||
|
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
|
||||||
|
checkpoint_loaded = True
|
||||||
|
logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if not checkpoint_loaded:
|
if not checkpoint_loaded:
|
||||||
# New model - no synthetic data
|
# New model - no synthetic data
|
||||||
@@ -7951,28 +7976,29 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
self.checkpoint_manager = get_checkpoint_manager()
|
self.checkpoint_manager = get_checkpoint_manager()
|
||||||
|
|
||||||
# Initialize model states dictionary to track performance
|
# Initialize model states dictionary to track performance (only if not already initialized)
|
||||||
self.model_states = {
|
if not hasattr(self, 'model_states') or self.model_states is None:
|
||||||
"dqn": {
|
self.model_states = {
|
||||||
"initial_loss": None,
|
"dqn": {
|
||||||
"current_loss": None,
|
"initial_loss": None,
|
||||||
"best_loss": float("inf"),
|
"current_loss": None,
|
||||||
"checkpoint_loaded": False,
|
"best_loss": float("inf"),
|
||||||
},
|
"checkpoint_loaded": False,
|
||||||
"cnn": {
|
},
|
||||||
"initial_loss": None,
|
"cnn": {
|
||||||
"current_loss": None,
|
"initial_loss": None,
|
||||||
"best_loss": float("inf"),
|
"current_loss": None,
|
||||||
"checkpoint_loaded": False,
|
"best_loss": float("inf"),
|
||||||
},
|
"checkpoint_loaded": False,
|
||||||
"cob_rl": {
|
},
|
||||||
"initial_loss": None,
|
"cob_rl": {
|
||||||
"current_loss": None,
|
"initial_loss": None,
|
||||||
"best_loss": float("inf"),
|
"current_loss": None,
|
||||||
"checkpoint_loaded": False,
|
"best_loss": float("inf"),
|
||||||
},
|
"checkpoint_loaded": False,
|
||||||
"extrema": {
|
},
|
||||||
"initial_loss": None,
|
"extrema": {
|
||||||
|
"initial_loss": None,
|
||||||
"current_loss": None,
|
"current_loss": None,
|
||||||
"best_loss": float("inf"),
|
"best_loss": float("inf"),
|
||||||
"checkpoint_loaded": False,
|
"checkpoint_loaded": False,
|
||||||
|
88
debug_checkpoint_loading.py
Normal file
88
debug_checkpoint_loading.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Debug script to test checkpoint loading
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append('.')
|
||||||
|
|
||||||
|
def test_checkpoint_loading():
|
||||||
|
"""Test checkpoint loading for all models"""
|
||||||
|
print("=== Testing Checkpoint Loading ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from utils.database_manager import get_database_manager
|
||||||
|
from utils.checkpoint_manager import load_best_checkpoint
|
||||||
|
|
||||||
|
db = get_database_manager()
|
||||||
|
|
||||||
|
# Test models that should have checkpoints
|
||||||
|
models = ['dqn_agent', 'enhanced_cnn', 'cob_rl_model', 'extrema_trainer']
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
print(f"\n--- Testing {model} ---")
|
||||||
|
|
||||||
|
# Check database
|
||||||
|
checkpoints = db.list_checkpoints(model)
|
||||||
|
print(f"DB checkpoints: {len(checkpoints)}")
|
||||||
|
|
||||||
|
if checkpoints:
|
||||||
|
best = db.get_best_checkpoint_metadata(model)
|
||||||
|
if best:
|
||||||
|
print(f"Best checkpoint: {best.checkpoint_id}")
|
||||||
|
print(f"File path: {best.file_path}")
|
||||||
|
print(f"File exists: {os.path.exists(best.file_path)}")
|
||||||
|
print(f"Active: {best.is_active}")
|
||||||
|
print(f"Performance metrics: {best.performance_metrics}")
|
||||||
|
else:
|
||||||
|
print("No best checkpoint found in DB")
|
||||||
|
|
||||||
|
# Test filesystem fallback
|
||||||
|
result = load_best_checkpoint(model)
|
||||||
|
if result:
|
||||||
|
file_path, metadata = result
|
||||||
|
print(f"Filesystem fallback: {file_path}")
|
||||||
|
print(f"File exists: {os.path.exists(file_path)}")
|
||||||
|
print(f"Metadata: {getattr(metadata, 'checkpoint_id', 'N/A')}")
|
||||||
|
else:
|
||||||
|
print("No filesystem fallback found")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def test_model_initialization():
|
||||||
|
"""Test model initialization with checkpoint loading"""
|
||||||
|
print("\n=== Testing Model Initialization ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test DQN Agent initialization
|
||||||
|
print("\n--- Testing DQN Agent ---")
|
||||||
|
from NN.models.dqn_agent import DQNAgent
|
||||||
|
|
||||||
|
# Create a minimal DQN agent
|
||||||
|
agent = DQNAgent(
|
||||||
|
state_shape=(10,), # Simple state shape
|
||||||
|
n_actions=3,
|
||||||
|
model_name="dqn_agent",
|
||||||
|
enable_checkpoints=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Agent created with checkpoints enabled: {agent.enable_checkpoints}")
|
||||||
|
print(f"Model name: {agent.model_name}")
|
||||||
|
|
||||||
|
# Try to load checkpoint manually
|
||||||
|
print("Attempting manual checkpoint load...")
|
||||||
|
agent.load_best_checkpoint()
|
||||||
|
print("Manual checkpoint load completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in model initialization: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_checkpoint_loading()
|
||||||
|
test_model_initialization()
|
74
force_refresh_dashboard.py
Normal file
74
force_refresh_dashboard.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Force refresh dashboard model states to show correct DQN status
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
sys.path.append('.')
|
||||||
|
|
||||||
|
def force_refresh_dashboard():
|
||||||
|
"""Force refresh the dashboard to show correct model states"""
|
||||||
|
print("=== Forcing Dashboard Refresh ===")
|
||||||
|
|
||||||
|
# Try to hit the dashboard endpoint to force a refresh
|
||||||
|
dashboard_url = "http://localhost:8050"
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Attempting to refresh dashboard at {dashboard_url}...")
|
||||||
|
|
||||||
|
# Try to access the main dashboard page
|
||||||
|
response = requests.get(dashboard_url, timeout=5)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("✅ Dashboard is accessible and responding")
|
||||||
|
else:
|
||||||
|
print(f"⚠️ Dashboard responded with status code: {response.status_code}")
|
||||||
|
|
||||||
|
# Try to access the model states API endpoint if it exists
|
||||||
|
try:
|
||||||
|
api_response = requests.get(f"{dashboard_url}/api/model-states", timeout=5)
|
||||||
|
if api_response.status_code == 200:
|
||||||
|
print("✅ Model states API is accessible")
|
||||||
|
else:
|
||||||
|
print(f"⚠️ Model states API responded with: {api_response.status_code}")
|
||||||
|
except:
|
||||||
|
print("ℹ️ No model states API endpoint found (this is normal)")
|
||||||
|
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
print("❌ Dashboard is not running or not accessible")
|
||||||
|
print("Please start the dashboard with: python run_clean_dashboard.py")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error accessing dashboard: {e}")
|
||||||
|
|
||||||
|
# Also verify the model states are correct
|
||||||
|
print("\n=== Verifying Model States ===")
|
||||||
|
try:
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
|
||||||
|
dp = DataProvider()
|
||||||
|
orch = TradingOrchestrator(data_provider=dp)
|
||||||
|
|
||||||
|
states = orch.get_model_states()
|
||||||
|
if states and 'dqn' in states:
|
||||||
|
dqn_state = states['dqn']
|
||||||
|
checkpoint_loaded = dqn_state.get('checkpoint_loaded', False)
|
||||||
|
checkpoint_filename = dqn_state.get('checkpoint_filename', 'None')
|
||||||
|
|
||||||
|
print(f"✅ DQN checkpoint_loaded: {checkpoint_loaded}")
|
||||||
|
print(f"✅ DQN checkpoint_filename: {checkpoint_filename}")
|
||||||
|
|
||||||
|
if checkpoint_loaded:
|
||||||
|
print("🎯 DQN should show as ACTIVE with [CKPT], not FRESH")
|
||||||
|
else:
|
||||||
|
print("⚠️ DQN checkpoint not loaded - will show as FRESH")
|
||||||
|
else:
|
||||||
|
print("❌ No DQN state found")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error verifying model states: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
force_refresh_dashboard()
|
@@ -26,8 +26,6 @@ class ModelsTrainingPanel:
|
|||||||
data = self._gather_data()
|
data = self._gather_data()
|
||||||
|
|
||||||
content: List[html.Div] = []
|
content: List[html.Div] = []
|
||||||
content.append(self._create_header())
|
|
||||||
|
|
||||||
content.append(self._create_models_section(data.get("models", {})))
|
content.append(self._create_models_section(data.get("models", {})))
|
||||||
|
|
||||||
if data.get("training_status"):
|
if data.get("training_status"):
|
||||||
@@ -55,13 +53,21 @@ class ModelsTrainingPanel:
|
|||||||
if isinstance(stats, dict):
|
if isinstance(stats, dict):
|
||||||
stats_summary = stats
|
stats_summary = stats
|
||||||
|
|
||||||
# Model states (for best_loss and checkpoint flags)
|
# Model states (for best_loss and checkpoint flags) - use get_model_states() for updated checkpoint info
|
||||||
model_states: Dict[str, Dict[str, Any]] = getattr(self.orchestrator, "model_states", {}) or {}
|
model_states: Dict[str, Dict[str, Any]] = {}
|
||||||
|
if hasattr(self.orchestrator, "get_model_states"):
|
||||||
|
try:
|
||||||
|
model_states = self.orchestrator.get_model_states() or {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting model states: {e}")
|
||||||
|
model_states = getattr(self.orchestrator, "model_states", {}) or {}
|
||||||
|
|
||||||
# Build models block from stats_summary
|
# Build models block from stats_summary
|
||||||
for model_name, s in stats_summary.items():
|
for model_name, s in stats_summary.items():
|
||||||
model_info: Dict[str, Any] = {}
|
model_info: Dict[str, Any] = {}
|
||||||
try:
|
try:
|
||||||
|
# Use actual model name - no mapping needed
|
||||||
|
display_name = model_name.upper()
|
||||||
# Status: active if we have any inference info
|
# Status: active if we have any inference info
|
||||||
total_inf = int(s.get("total_inferences", 0) or 0)
|
total_inf = int(s.get("total_inferences", 0) or 0)
|
||||||
status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
|
status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
|
||||||
@@ -84,22 +90,35 @@ class ModelsTrainingPanel:
|
|||||||
"last_training": s.get("last_training_time"),
|
"last_training": s.get("last_training_time"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Checkpoint flags from orchestrator.model_states if present
|
# Checkpoint flags from orchestrator.model_states - map external names to internal keys
|
||||||
ckpt_loaded = False
|
ckpt_loaded = False
|
||||||
ckpt_failed = False
|
ckpt_failed = False
|
||||||
if model_states and model_name in model_states:
|
checkpoint_filename = None
|
||||||
ckpt_loaded = bool(model_states[model_name].get("checkpoint_loaded", False))
|
|
||||||
ckpt_failed = bool(model_states[model_name].get("checkpoint_failed", False))
|
# Map external model names to internal model state keys
|
||||||
|
state_key_map = {
|
||||||
|
"dqn_agent": "dqn",
|
||||||
|
"enhanced_cnn": "cnn",
|
||||||
|
"extrema_trainer": "extrema_trainer",
|
||||||
|
"decision": "decision",
|
||||||
|
"cob_rl_model": "cob_rl"
|
||||||
|
}
|
||||||
|
state_key = state_key_map.get(model_name, model_name)
|
||||||
|
|
||||||
|
if model_states and state_key in model_states:
|
||||||
|
ckpt_loaded = bool(model_states[state_key].get("checkpoint_loaded", False))
|
||||||
|
ckpt_failed = bool(model_states[state_key].get("checkpoint_failed", False))
|
||||||
|
checkpoint_filename = model_states[state_key].get("checkpoint_filename")
|
||||||
|
|
||||||
# If model started fresh, mark as stored after first successful autosave/best_loss update
|
# If model started fresh, mark as stored after first successful autosave/best_loss update
|
||||||
if not ckpt_loaded:
|
if not ckpt_loaded:
|
||||||
cur = model_states[model_name].get("current_loss")
|
cur = model_states[state_key].get("current_loss")
|
||||||
best = model_states[model_name].get("best_loss")
|
best = model_states[state_key].get("best_loss")
|
||||||
if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
|
if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
|
||||||
ckpt_loaded = True
|
ckpt_loaded = True
|
||||||
checkpoint_filename = model_states[model_name].get("checkpoint_filename")
|
|
||||||
|
|
||||||
model_info = {
|
model_info = {
|
||||||
"name": model_name,
|
"name": display_name,
|
||||||
"status": status,
|
"status": status,
|
||||||
"parameters": 0, # unknown; do not synthesize
|
"parameters": 0, # unknown; do not synthesize
|
||||||
"last_prediction": {
|
"last_prediction": {
|
||||||
@@ -112,7 +131,7 @@ class ModelsTrainingPanel:
|
|||||||
"routing_enabled": True,
|
"routing_enabled": True,
|
||||||
"checkpoint_loaded": ckpt_loaded,
|
"checkpoint_loaded": ckpt_loaded,
|
||||||
"checkpoint_failed": ckpt_failed,
|
"checkpoint_failed": ckpt_failed,
|
||||||
"checkpoint_filename": checkpoint_filename if model_states and model_name in model_states else None,
|
"checkpoint_filename": checkpoint_filename,
|
||||||
"loss_metrics": loss_metrics,
|
"loss_metrics": loss_metrics,
|
||||||
"timing_metrics": timing_metrics,
|
"timing_metrics": timing_metrics,
|
||||||
"signal_stats": {},
|
"signal_stats": {},
|
||||||
@@ -145,12 +164,6 @@ class ModelsTrainingPanel:
|
|||||||
logger.error(f"Error gathering panel data: {e}")
|
logger.error(f"Error gathering panel data: {e}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _create_header(self) -> html.Div:
|
|
||||||
return html.Div([
|
|
||||||
html.H6([html.I(className="fas fa-brain me-2 text-primary"), "Models & Training Progress"], className="mb-2"),
|
|
||||||
html.Button([html.I(className="fas fa-sync-alt me-1"), "Refresh"], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2")
|
|
||||||
], className="d-flex justify-content-between align-items-start")
|
|
||||||
|
|
||||||
def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
|
def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
|
||||||
cards = [self._create_model_card(name, info) for name, info in models.items()]
|
cards = [self._create_model_card(name, info) for name, info in models.items()]
|
||||||
return html.Div([
|
return html.Div([
|
||||||
|
Reference in New Issue
Block a user