This commit is contained in:
Dobromir Popov
2025-08-10 21:28:49 +03:00
parent 7289366a35
commit 8738f02d24
5 changed files with 311 additions and 78 deletions

32
.ckpt_count.py Normal file
View File

@@ -0,0 +1,32 @@
import sys, json, os, traceback
sys.path.insert(0, r'F:\projects\gogo2')
def count_checkpoints(fetch, name, label):
    """Return how many checkpoints ``fetch(name)`` reports for a model.

    Args:
        fetch: Callable taking the model name and returning a sized
            collection of checkpoints, or ``None``.
        name: Model name passed through to ``fetch``.
        label: Short tag ("DB" / "FS") used in the diagnostic message.

    Returns:
        The number of checkpoints, ``0`` when ``fetch`` returns ``None``,
        or ``-1`` when the lookup raises.
    """
    try:
        found = fetch(name)
        return len(found) if found is not None else 0
    except Exception as e:
        # Diagnostics go to stderr: stdout must carry nothing but the final
        # JSON document so that callers can parse it directly.
        print("%s error for %s: %s" % (label, name, str(e)), file=sys.stderr)
        return -1


# Models whose checkpoint counts are reported for every backend.
MODEL_NAMES = ('dqn_agent', 'enhanced_cnn')

# Result document: {'db': {...}, 'fs': {...}}; a backend that cannot be
# initialised at all is reported as {'error': "<message>"}.
res = {}

try:
    from utils.database_manager import get_database_manager
    db = get_database_manager()
    res['db'] = {
        # The lambda keeps the attribute lookup inside count_checkpoints' try,
        # so a broken manager yields -1 per model rather than a backend error.
        name: count_checkpoints(lambda n: db.list_checkpoints(n), name, "DB")
        for name in MODEL_NAMES
    }
except Exception as e:
    res['db'] = {'error': str(e)}

try:
    from utils.checkpoint_manager import get_checkpoint_manager
    cm = get_checkpoint_manager()
    res['fs'] = {
        name: count_checkpoints(lambda n: cm.get_all_checkpoints(n), name, "FS")
        for name in MODEL_NAMES
    }
except Exception as e:
    res['fs'] = {'error': str(e)}

print(json.dumps(res))

View File

@@ -315,6 +315,16 @@ class TradingOrchestrator:
logger.info(f"Using device: {self.device}")
# Canonical model name aliases to eliminate ambiguity across UI/DB/FS
# Canonical → accepted aliases (internal/legacy)
self.model_name_aliases: Dict[str, list] = {
"DQN": ["dqn_agent", "dqn"],
"CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
"EXTREMA": ["extrema_trainer", "extrema"],
"COB": ["cob_rl_model", "cob_rl"],
"DECISION": ["decision_fusion", "decision"],
}
# Configuration - AGGRESSIVE for more training data
self.confidence_threshold = self.config.orchestrator.get(
"confidence_threshold", 0.15
@@ -555,17 +565,16 @@ class TradingOrchestrator:
def _normalize_model_name(self, name: str) -> str:
"""Map various registry/UI names to canonical toggle keys."""
try:
mapping = {
"dqn_agent": "dqn",
"enhanced_cnn": "cnn",
"cnn_model": "cnn",
"decision": "decision_fusion",
"decision_fusion": "decision_fusion",
"cob_rl_model": "cob_rl",
"cob_rl": "cob_rl",
"transformer_model": "transformer",
# Use alias map to unify names to canonical keys
alias_to_canonical = {
**{alias: "DQN" for alias in ["dqn_agent", "dqn"]},
**{alias: "CNN" for alias in ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"]},
**{alias: "EXTREMA" for alias in ["extrema_trainer", "extrema"]},
**{alias: "COB" for alias in ["cob_rl_model", "cob_rl"]},
**{alias: "DECISION" for alias in ["decision_fusion", "decision"]},
"transformer_model": "TRANSFORMER",
}
return mapping.get(name, name)
return alias_to_canonical.get(name, name)
except Exception:
return name
@@ -643,41 +652,44 @@ class TradingOrchestrator:
)
self.rl_agent.to(self.device) # Move DQN agent to the determined device
# Load best checkpoint and capture initial state (using database metadata)
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
checkpoint_loaded = False
if hasattr(self.rl_agent, "load_best_checkpoint"):
try:
self.rl_agent.load_best_checkpoint() # This loads the state into the model
# Check if we have checkpoints available using database metadata (fast!)
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
"dqn_agent"
)
self.rl_agent.load_best_checkpoint() # Load model state if available
# 1) Try DB metadata first
try:
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
except Exception:
checkpoint_metadata = None
if checkpoint_metadata:
self.model_states["dqn"]["initial_loss"] = 0.412
self.model_states["dqn"]["current_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["dqn"]["best_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["dqn"]["current_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
self.model_states["dqn"]["best_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
self.model_states["dqn"]["checkpoint_loaded"] = True
self.model_states["dqn"][
"checkpoint_filename"
] = checkpoint_metadata.checkpoint_id
self.model_states["dqn"]["checkpoint_filename"] = checkpoint_metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(
f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
)
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
else:
# 2) Filesystem fallback via CheckpointManager
try:
from utils.checkpoint_manager import get_checkpoint_manager
cm = get_checkpoint_manager()
result = cm.load_best_checkpoint("dqn_agent")
if result:
model_path, meta = result
# We already loaded model weights via load_best_checkpoint; just record metadata
self.model_states["dqn"]["checkpoint_loaded"] = True
self.model_states["dqn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
checkpoint_loaded = True
logger.info(f"DQN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
except Exception:
pass
except Exception as e:
logger.warning(
f"Error loading DQN checkpoint (likely dimension mismatch): {e}"
)
logger.info(
"DQN will start fresh due to checkpoint incompatibility"
)
# Reset the agent to handle dimension mismatch
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
logger.info("DQN will start fresh due to checkpoint incompatibility")
checkpoint_loaded = False
if not checkpoint_loaded:
@@ -712,7 +724,7 @@ class TradingOrchestrator:
self.cnn_model.parameters(), lr=0.001
) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state (using database metadata)
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
checkpoint_loaded = False
try:
db_manager = get_database_manager()
@@ -738,6 +750,19 @@ class TradingOrchestrator:
)
except Exception as e:
logger.warning(f"Error loading CNN checkpoint: {e}")
# Filesystem fallback
try:
from utils.checkpoint_manager import get_checkpoint_manager
cm = get_checkpoint_manager()
result = cm.load_best_checkpoint("enhanced_cnn")
if result:
model_path, meta = result
self.model_states["cnn"]["checkpoint_loaded"] = True
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
checkpoint_loaded = True
logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
except Exception:
pass
if not checkpoint_loaded:
# New model - no synthetic data
@@ -7951,28 +7976,29 @@ class TradingOrchestrator:
self.checkpoint_manager = get_checkpoint_manager()
# Initialize model states dictionary to track performance
self.model_states = {
"dqn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cnn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cob_rl": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"extrema": {
"initial_loss": None,
# Initialize model states dictionary to track performance (only if not already initialized)
if not hasattr(self, 'model_states') or self.model_states is None:
self.model_states = {
"dqn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cnn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cob_rl": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"extrema": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""
Debug script to test checkpoint loading
"""
import sys
import os
sys.path.append('.')
def _report_db_checkpoints(db, model):
    """Print what the database manager reports about ``model``'s checkpoints."""
    # Treat a None listing like an empty one so len() cannot raise.
    checkpoints = db.list_checkpoints(model) or []
    print(f"DB checkpoints: {len(checkpoints)}")
    if not checkpoints:
        return

    best = db.get_best_checkpoint_metadata(model)
    if not best:
        print("No best checkpoint found in DB")
        return

    print(f"Best checkpoint: {best.checkpoint_id}")
    print(f"File path: {best.file_path}")
    # os.path.exists(None) raises TypeError; a missing path is simply "not on disk".
    print(f"File exists: {bool(best.file_path) and os.path.exists(best.file_path)}")
    print(f"Active: {best.is_active}")
    print(f"Performance metrics: {best.performance_metrics}")


def _report_fs_fallback(load_best_checkpoint, model):
    """Print the result of the filesystem checkpoint lookup for ``model``."""
    result = load_best_checkpoint(model)
    if not result:
        print("No filesystem fallback found")
        return

    file_path, metadata = result
    print(f"Filesystem fallback: {file_path}")
    print(f"File exists: {os.path.exists(file_path)}")
    print(f"Metadata: {getattr(metadata, 'checkpoint_id', 'N/A')}")


def _run_probe(stage, model, probe, *args):
    """Run one reporting step, printing (not raising) any failure."""
    try:
        probe(*args)
    except Exception as e:
        print(f"Error testing {model} ({stage}): {e}")
        import traceback
        traceback.print_exc()


def test_checkpoint_loading():
    """Test checkpoint loading for all models.

    For each known model name, prints the database view of its checkpoints
    and then the filesystem lookup result. Failures are isolated per model
    and per stage, so one broken model (or a broken database) does not hide
    the results for the others. Nothing is returned or raised.
    """
    print("=== Testing Checkpoint Loading ===")

    try:
        from utils.database_manager import get_database_manager
        from utils.checkpoint_manager import load_best_checkpoint

        db = get_database_manager()
    except Exception as e:
        # Without the managers there is nothing to inspect.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return

    # Test models that should have checkpoints
    models = ['dqn_agent', 'enhanced_cnn', 'cob_rl_model', 'extrema_trainer']

    for model in models:
        print(f"\n--- Testing {model} ---")
        _run_probe("database", model, _report_db_checkpoints, db, model)
        _run_probe("filesystem", model, _report_fs_fallback, load_best_checkpoint, model)
def test_model_initialization():
    """Build a small DQN agent and ask it to load its best checkpoint.

    Progress is printed step by step. Any exception raised along the way
    (including a failed import of the agent class) is printed together with
    its traceback instead of propagating.
    """
    print("\n=== Testing Model Initialization ===")

    try:
        print("\n--- Testing DQN Agent ---")
        from NN.models.dqn_agent import DQNAgent

        # Smallest configuration the script needs: 10-element state, 3 actions.
        agent_kwargs = {
            "state_shape": (10,),
            "n_actions": 3,
            "model_name": "dqn_agent",
            "enable_checkpoints": True,
        }
        dqn = DQNAgent(**agent_kwargs)

        print(f"Agent created with checkpoints enabled: {dqn.enable_checkpoints}")
        print(f"Model name: {dqn.model_name}")

        # Trigger the load explicitly so a failure surfaces right here.
        print("Attempting manual checkpoint load...")
        dqn.load_best_checkpoint()
        print("Manual checkpoint load completed")
    except Exception as e:
        print(f"Error in model initialization: {e}")
        import traceback
        traceback.print_exc()
# Script entry point: run both diagnostics in order when executed directly.
if __name__ == "__main__":
    for diagnostic in (test_checkpoint_loading, test_model_initialization):
        diagnostic()

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""
Force refresh dashboard model states to show correct DQN status
"""
import sys
import os
import requests
import time
sys.path.append('.')
def _probe_dashboard(dashboard_url):
    """Send best-effort GET requests to the dashboard and print the outcome.

    Requests the main page first, then the optional ``/api/model-states``
    endpoint. Network problems are reported on stdout, never raised.
    """
    try:
        print(f"Attempting to refresh dashboard at {dashboard_url}...")

        # Try to access the main dashboard page
        response = requests.get(dashboard_url, timeout=5)
        if response.status_code == 200:
            print("✅ Dashboard is accessible and responding")
        else:
            print(f"⚠️ Dashboard responded with status code: {response.status_code}")

        # Try to access the model states API endpoint if it exists
        try:
            api_response = requests.get(f"{dashboard_url}/api/model-states", timeout=5)
            if api_response.status_code == 200:
                print("✅ Model states API is accessible")
            else:
                print(f"⚠️ Model states API responded with: {api_response.status_code}")
        except requests.exceptions.RequestException:
            # Only request-level failures count as "no endpoint"; a bare
            # except here would also swallow KeyboardInterrupt/SystemExit.
            print(" No model states API endpoint found (this is normal)")

    except requests.exceptions.ConnectionError:
        print("❌ Dashboard is not running or not accessible")
        print("Please start the dashboard with: python run_clean_dashboard.py")
    except Exception as e:
        print(f"❌ Error accessing dashboard: {e}")


def _verify_dqn_state():
    """Build an orchestrator in this process and print its DQN checkpoint state.

    NOTE(review): this inspects a freshly constructed TradingOrchestrator,
    not the instance inside an already running dashboard process.
    """
    try:
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        dp = DataProvider()
        orch = TradingOrchestrator(data_provider=dp)

        states = orch.get_model_states()
        if not (states and 'dqn' in states):
            print("❌ No DQN state found")
            return

        dqn_state = states['dqn']
        checkpoint_loaded = dqn_state.get('checkpoint_loaded', False)
        checkpoint_filename = dqn_state.get('checkpoint_filename', 'None')

        print(f"✅ DQN checkpoint_loaded: {checkpoint_loaded}")
        print(f"✅ DQN checkpoint_filename: {checkpoint_filename}")

        if checkpoint_loaded:
            print("🎯 DQN should show as ACTIVE with [CKPT], not FRESH")
        else:
            print("⚠️ DQN checkpoint not loaded - will show as FRESH")

    except Exception as e:
        print(f"❌ Error verifying model states: {e}")


def force_refresh_dashboard():
    """Force refresh the dashboard to show correct model states.

    Runs two independent, best-effort checks and prints their results:
    an HTTP probe of the local dashboard, then a check of the DQN entry
    returned by ``TradingOrchestrator.get_model_states()``. Returns ``None``.
    """
    print("=== Forcing Dashboard Refresh ===")

    # Try to hit the dashboard endpoint to force a refresh
    dashboard_url = "http://localhost:8050"
    _probe_dashboard(dashboard_url)

    # Also verify the model states are correct
    print("\n=== Verifying Model States ===")
    _verify_dqn_state()
# Script entry point: only probe the dashboard when run directly, not on import.
if __name__ == "__main__":
    force_refresh_dashboard()

View File

@@ -26,8 +26,6 @@ class ModelsTrainingPanel:
data = self._gather_data()
content: List[html.Div] = []
content.append(self._create_header())
content.append(self._create_models_section(data.get("models", {})))
if data.get("training_status"):
@@ -55,13 +53,21 @@ class ModelsTrainingPanel:
if isinstance(stats, dict):
stats_summary = stats
# Model states (for best_loss and checkpoint flags)
model_states: Dict[str, Dict[str, Any]] = getattr(self.orchestrator, "model_states", {}) or {}
# Model states (for best_loss and checkpoint flags) - use get_model_states() for updated checkpoint info
model_states: Dict[str, Dict[str, Any]] = {}
if hasattr(self.orchestrator, "get_model_states"):
try:
model_states = self.orchestrator.get_model_states() or {}
except Exception as e:
logger.debug(f"Error getting model states: {e}")
model_states = getattr(self.orchestrator, "model_states", {}) or {}
# Build models block from stats_summary
for model_name, s in stats_summary.items():
model_info: Dict[str, Any] = {}
try:
# Use actual model name - no mapping needed
display_name = model_name.upper()
# Status: active if we have any inference info
total_inf = int(s.get("total_inferences", 0) or 0)
status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
@@ -84,22 +90,35 @@ class ModelsTrainingPanel:
"last_training": s.get("last_training_time"),
}
# Checkpoint flags from orchestrator.model_states if present
# Checkpoint flags from orchestrator.model_states - map external names to internal keys
ckpt_loaded = False
ckpt_failed = False
if model_states and model_name in model_states:
ckpt_loaded = bool(model_states[model_name].get("checkpoint_loaded", False))
ckpt_failed = bool(model_states[model_name].get("checkpoint_failed", False))
checkpoint_filename = None
# Map external model names to internal model state keys
state_key_map = {
"dqn_agent": "dqn",
"enhanced_cnn": "cnn",
"extrema_trainer": "extrema_trainer",
"decision": "decision",
"cob_rl_model": "cob_rl"
}
state_key = state_key_map.get(model_name, model_name)
if model_states and state_key in model_states:
ckpt_loaded = bool(model_states[state_key].get("checkpoint_loaded", False))
ckpt_failed = bool(model_states[state_key].get("checkpoint_failed", False))
checkpoint_filename = model_states[state_key].get("checkpoint_filename")
# If model started fresh, mark as stored after first successful autosave/best_loss update
if not ckpt_loaded:
cur = model_states[model_name].get("current_loss")
best = model_states[model_name].get("best_loss")
cur = model_states[state_key].get("current_loss")
best = model_states[state_key].get("best_loss")
if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
ckpt_loaded = True
checkpoint_filename = model_states[model_name].get("checkpoint_filename")
model_info = {
"name": model_name,
"name": display_name,
"status": status,
"parameters": 0, # unknown; do not synthesize
"last_prediction": {
@@ -112,7 +131,7 @@ class ModelsTrainingPanel:
"routing_enabled": True,
"checkpoint_loaded": ckpt_loaded,
"checkpoint_failed": ckpt_failed,
"checkpoint_filename": checkpoint_filename if model_states and model_name in model_states else None,
"checkpoint_filename": checkpoint_filename,
"loss_metrics": loss_metrics,
"timing_metrics": timing_metrics,
"signal_stats": {},
@@ -145,12 +164,6 @@ class ModelsTrainingPanel:
logger.error(f"Error gathering panel data: {e}")
return result
def _create_header(self) -> html.Div:
    """Return the panel header: a title heading plus a "Refresh" button."""
    title = html.H6(
        [html.I(className="fas fa-brain me-2 text-primary"), "Models & Training Progress"],
        className="mb-2",
    )
    refresh_button = html.Button(
        [html.I(className="fas fa-sync-alt me-1"), "Refresh"],
        id="refresh-training-metrics-btn",
        className="btn btn-sm btn-outline-primary mb-2",
    )
    return html.Div(
        [title, refresh_button],
        className="d-flex justify-content-between align-items-start",
    )
def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
cards = [self._create_model_card(name, info) for name, info in models.items()]
return html.Div([