fixes: checkpoint loading (DB metadata with filesystem fallback), canonical model name aliases, and dashboard model-state display
.ckpt_count.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import sys, json, os, traceback
sys.path.insert(0, r'F:\projects\gogo2')
res={}
try:
    from utils.database_manager import get_database_manager
    db=get_database_manager()
    def db_count(name):
        try:
            lst = db.list_checkpoints(name)
            return len(lst) if lst is not None else 0
        except Exception as e:
            print("DB error for %s: %s" % (name, str(e)))
            return -1
    res.setdefault('db', {})['dqn_agent']=db_count('dqn_agent')
    res['db']['enhanced_cnn']=db_count('enhanced_cnn')
except Exception as e:
    res['db']={'error': str(e)}
try:
    from utils.checkpoint_manager import get_checkpoint_manager
    cm=get_checkpoint_manager()
    def fs_count(name):
        try:
            lst = cm.get_all_checkpoints(name)
            return len(lst) if lst is not None else 0
        except Exception as e:
            print("FS error for %s: %s" % (name, str(e)))
            return -1
    res.setdefault('fs', {})['dqn_agent']=fs_count('dqn_agent')
    res['fs']['enhanced_cnn']=fs_count('enhanced_cnn')
except Exception as e:
    res['fs']={'error': str(e)}
print(json.dumps(res))
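The helper prints a single JSON object as its last line of output; a per-model failure is recorded as -1 and a failed import collapses a whole section into an error entry. A minimal consumer sketch, assuming the script is run from the repository root (the counts shown are placeholders, not real data):

import json
import subprocess

# Run the counter and keep only the final stdout line, which carries the JSON
# summary; earlier lines may be "DB error for ..." / "FS error for ..." diagnostics.
out = subprocess.run(
    ["python", ".ckpt_count.py"],
    capture_output=True, text=True,
).stdout.strip().splitlines()[-1]
counts = json.loads(out)
# e.g. {"db": {"dqn_agent": 5, "enhanced_cnn": 3},
#       "fs": {"dqn_agent": 5, "enhanced_cnn": 3}}
print(counts.get("db"), counts.get("fs"))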
core/orchestrator.py
@@ -315,6 +315,16 @@ class TradingOrchestrator:
 
         logger.info(f"Using device: {self.device}")
+
+        # Canonical model name aliases to eliminate ambiguity across UI/DB/FS
+        # Canonical → accepted aliases (internal/legacy)
+        self.model_name_aliases: Dict[str, list] = {
+            "DQN": ["dqn_agent", "dqn"],
+            "CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
+            "EXTREMA": ["extrema_trainer", "extrema"],
+            "COB": ["cob_rl_model", "cob_rl"],
+            "DECISION": ["decision_fusion", "decision"],
+        }
 
         # Configuration - AGGRESSIVE for more training data
         self.confidence_threshold = self.config.orchestrator.get(
             "confidence_threshold", 0.15
@@ -555,17 +565,16 @@ class TradingOrchestrator:
     def _normalize_model_name(self, name: str) -> str:
         """Map various registry/UI names to canonical toggle keys."""
         try:
-            mapping = {
-                "dqn_agent": "dqn",
-                "enhanced_cnn": "cnn",
-                "cnn_model": "cnn",
-                "decision": "decision_fusion",
-                "decision_fusion": "decision_fusion",
-                "cob_rl_model": "cob_rl",
-                "cob_rl": "cob_rl",
-                "transformer_model": "transformer",
+            # Use alias map to unify names to canonical keys
+            alias_to_canonical = {
+                **{alias: "DQN" for alias in ["dqn_agent", "dqn"]},
+                **{alias: "CNN" for alias in ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"]},
+                **{alias: "EXTREMA" for alias in ["extrema_trainer", "extrema"]},
+                **{alias: "COB" for alias in ["cob_rl_model", "cob_rl"]},
+                **{alias: "DECISION" for alias in ["decision_fusion", "decision"]},
+                "transformer_model": "TRANSFORMER",
             }
-            return mapping.get(name, name)
+            return alias_to_canonical.get(name, name)
         except Exception:
             return name
 
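A quick way to exercise the new alias handling is to call the normalizer directly. The expected results below follow from the alias table added at line 315; constructing the orchestrator with a DataProvider mirrors force_refresh_dashboard.py and is only a sketch of how the call might be driven:

from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator

orch = TradingOrchestrator(data_provider=DataProvider())

# Legacy/registry names collapse onto the canonical keys from model_name_aliases.
assert orch._normalize_model_name("enhanced_cnn") == "CNN"
assert orch._normalize_model_name("dqn") == "DQN"
assert orch._normalize_model_name("decision_fusion") == "DECISION"
assert orch._normalize_model_name("transformer_model") == "TRANSFORMER"
# Unknown names pass through unchanged.
assert orch._normalize_model_name("my_custom_model") == "my_custom_model"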
@@ -643,41 +652,44 @@ class TradingOrchestrator:
             )
             self.rl_agent.to(self.device)  # Move DQN agent to the determined device
 
-            # Load best checkpoint and capture initial state (using database metadata)
+            # Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
             checkpoint_loaded = False
             if hasattr(self.rl_agent, "load_best_checkpoint"):
                 try:
-                    self.rl_agent.load_best_checkpoint()  # This loads the state into the model
-                    # Check if we have checkpoints available using database metadata (fast!)
-                    db_manager = get_database_manager()
-                    checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
-                        "dqn_agent"
-                    )
+                    self.rl_agent.load_best_checkpoint()  # Load model state if available
+                    # 1) Try DB metadata first
+                    try:
+                        db_manager = get_database_manager()
+                        checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
+                    except Exception:
+                        checkpoint_metadata = None
                     if checkpoint_metadata:
                         self.model_states["dqn"]["initial_loss"] = 0.412
-                        self.model_states["dqn"]["current_loss"] = (
-                            checkpoint_metadata.performance_metrics.get("loss", 0.0)
-                        )
-                        self.model_states["dqn"]["best_loss"] = (
-                            checkpoint_metadata.performance_metrics.get("loss", 0.0)
-                        )
+                        self.model_states["dqn"]["current_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
+                        self.model_states["dqn"]["best_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
                         self.model_states["dqn"]["checkpoint_loaded"] = True
-                        self.model_states["dqn"][
-                            "checkpoint_filename"
-                        ] = checkpoint_metadata.checkpoint_id
+                        self.model_states["dqn"]["checkpoint_filename"] = checkpoint_metadata.checkpoint_id
                         checkpoint_loaded = True
                         loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
-                        logger.info(
-                            f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
-                        )
+                        logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
+                    else:
+                        # 2) Filesystem fallback via CheckpointManager
+                        try:
+                            from utils.checkpoint_manager import get_checkpoint_manager
+                            cm = get_checkpoint_manager()
+                            result = cm.load_best_checkpoint("dqn_agent")
+                            if result:
+                                model_path, meta = result
+                                # We already loaded model weights via load_best_checkpoint; just record metadata
+                                self.model_states["dqn"]["checkpoint_loaded"] = True
+                                self.model_states["dqn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
+                                checkpoint_loaded = True
+                                logger.info(f"DQN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
+                        except Exception:
+                            pass
                 except Exception as e:
-                    logger.warning(
-                        f"Error loading DQN checkpoint (likely dimension mismatch): {e}"
-                    )
-                    logger.info(
-                        "DQN will start fresh due to checkpoint incompatibility"
-                    )
-                    # Reset the agent to handle dimension mismatch
+                    logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
+                    logger.info("DQN will start fresh due to checkpoint incompatibility")
                     checkpoint_loaded = False
 
             if not checkpoint_loaded:
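The CNN block below repeats the same two-step lookup: database metadata first, then a filesystem scan through CheckpointManager. A standalone sketch of that pattern, using only the calls that appear in this diff (the helper name itself is hypothetical):

from utils.database_manager import get_database_manager
from utils.checkpoint_manager import get_checkpoint_manager

def find_best_checkpoint_id(model_name: str):
    """Return the best checkpoint id for model_name, or None if nothing is stored."""
    # 1) Fast path: checkpoint metadata recorded in the database
    try:
        meta = get_database_manager().get_best_checkpoint_metadata(model_name)
        if meta:
            return meta.checkpoint_id
    except Exception:
        pass
    # 2) Fallback: let CheckpointManager locate the best file on disk
    try:
        result = get_checkpoint_manager().load_best_checkpoint(model_name)
        if result:
            _path, meta = result
            return getattr(meta, "checkpoint_id", None)
    except Exception:
        pass
    return None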
@@ -712,7 +724,7 @@ class TradingOrchestrator:
                 self.cnn_model.parameters(), lr=0.001
             )  # Initialize optimizer for CNN
 
-            # Load best checkpoint and capture initial state (using database metadata)
+            # Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
             checkpoint_loaded = False
             try:
                 db_manager = get_database_manager()
@@ -738,6 +750,19 @@ class TradingOrchestrator:
                     )
             except Exception as e:
                 logger.warning(f"Error loading CNN checkpoint: {e}")
+                # Filesystem fallback
+                try:
+                    from utils.checkpoint_manager import get_checkpoint_manager
+                    cm = get_checkpoint_manager()
+                    result = cm.load_best_checkpoint("enhanced_cnn")
+                    if result:
+                        model_path, meta = result
+                        self.model_states["cnn"]["checkpoint_loaded"] = True
+                        self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
+                        checkpoint_loaded = True
+                        logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
+                except Exception:
+                    pass
 
             if not checkpoint_loaded:
                 # New model - no synthetic data
@@ -7951,28 +7976,29 @@ class TradingOrchestrator:
 
         self.checkpoint_manager = get_checkpoint_manager()
 
-        # Initialize model states dictionary to track performance
-        self.model_states = {
-            "dqn": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "cnn": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "cob_rl": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "extrema": {
-                "initial_loss": None,
+        # Initialize model states dictionary to track performance (only if not already initialized)
+        if not hasattr(self, 'model_states') or self.model_states is None:
+            self.model_states = {
+                "dqn": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "cnn": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "cob_rl": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "extrema": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
debug_checkpoint_loading.py (new file, 88 lines)
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""
Debug script to test checkpoint loading
"""

import sys
import os
sys.path.append('.')

def test_checkpoint_loading():
    """Test checkpoint loading for all models"""
    print("=== Testing Checkpoint Loading ===")

    try:
        from utils.database_manager import get_database_manager
        from utils.checkpoint_manager import load_best_checkpoint

        db = get_database_manager()

        # Test models that should have checkpoints
        models = ['dqn_agent', 'enhanced_cnn', 'cob_rl_model', 'extrema_trainer']

        for model in models:
            print(f"\n--- Testing {model} ---")

            # Check database
            checkpoints = db.list_checkpoints(model)
            print(f"DB checkpoints: {len(checkpoints)}")

            if checkpoints:
                best = db.get_best_checkpoint_metadata(model)
                if best:
                    print(f"Best checkpoint: {best.checkpoint_id}")
                    print(f"File path: {best.file_path}")
                    print(f"File exists: {os.path.exists(best.file_path)}")
                    print(f"Active: {best.is_active}")
                    print(f"Performance metrics: {best.performance_metrics}")
                else:
                    print("No best checkpoint found in DB")

            # Test filesystem fallback
            result = load_best_checkpoint(model)
            if result:
                file_path, metadata = result
                print(f"Filesystem fallback: {file_path}")
                print(f"File exists: {os.path.exists(file_path)}")
                print(f"Metadata: {getattr(metadata, 'checkpoint_id', 'N/A')}")
            else:
                print("No filesystem fallback found")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

def test_model_initialization():
    """Test model initialization with checkpoint loading"""
    print("\n=== Testing Model Initialization ===")

    try:
        # Test DQN Agent initialization
        print("\n--- Testing DQN Agent ---")
        from NN.models.dqn_agent import DQNAgent

        # Create a minimal DQN agent
        agent = DQNAgent(
            state_shape=(10,),  # Simple state shape
            n_actions=3,
            model_name="dqn_agent",
            enable_checkpoints=True
        )

        print(f"Agent created with checkpoints enabled: {agent.enable_checkpoints}")
        print(f"Model name: {agent.model_name}")

        # Try to load checkpoint manually
        print("Attempting manual checkpoint load...")
        agent.load_best_checkpoint()
        print("Manual checkpoint load completed")

    except Exception as e:
        print(f"Error in model initialization: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    test_checkpoint_loading()
    test_model_initialization()
force_refresh_dashboard.py (new file, 74 lines)
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""
Force refresh dashboard model states to show correct DQN status
"""

import sys
import os
import requests
import time
sys.path.append('.')

def force_refresh_dashboard():
    """Force refresh the dashboard to show correct model states"""
    print("=== Forcing Dashboard Refresh ===")

    # Try to hit the dashboard endpoint to force a refresh
    dashboard_url = "http://localhost:8050"

    try:
        print(f"Attempting to refresh dashboard at {dashboard_url}...")

        # Try to access the main dashboard page
        response = requests.get(dashboard_url, timeout=5)
        if response.status_code == 200:
            print("✅ Dashboard is accessible and responding")
        else:
            print(f"⚠️ Dashboard responded with status code: {response.status_code}")

        # Try to access the model states API endpoint if it exists
        try:
            api_response = requests.get(f"{dashboard_url}/api/model-states", timeout=5)
            if api_response.status_code == 200:
                print("✅ Model states API is accessible")
            else:
                print(f"⚠️ Model states API responded with: {api_response.status_code}")
        except:
            print("ℹ️ No model states API endpoint found (this is normal)")

    except requests.exceptions.ConnectionError:
        print("❌ Dashboard is not running or not accessible")
        print("Please start the dashboard with: python run_clean_dashboard.py")
    except Exception as e:
        print(f"❌ Error accessing dashboard: {e}")

    # Also verify the model states are correct
    print("\n=== Verifying Model States ===")
    try:
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        dp = DataProvider()
        orch = TradingOrchestrator(data_provider=dp)

        states = orch.get_model_states()
        if states and 'dqn' in states:
            dqn_state = states['dqn']
            checkpoint_loaded = dqn_state.get('checkpoint_loaded', False)
            checkpoint_filename = dqn_state.get('checkpoint_filename', 'None')

            print(f"✅ DQN checkpoint_loaded: {checkpoint_loaded}")
            print(f"✅ DQN checkpoint_filename: {checkpoint_filename}")

            if checkpoint_loaded:
                print("🎯 DQN should show as ACTIVE with [CKPT], not FRESH")
            else:
                print("⚠️ DQN checkpoint not loaded - will show as FRESH")
        else:
            print("❌ No DQN state found")

    except Exception as e:
        print(f"❌ Error verifying model states: {e}")

if __name__ == "__main__":
    force_refresh_dashboard()
@@ -26,8 +26,6 @@ class ModelsTrainingPanel:
         data = self._gather_data()
 
         content: List[html.Div] = []
         content.append(self._create_header())
-
         content.append(self._create_models_section(data.get("models", {})))
-
         if data.get("training_status"):
@@ -55,13 +53,21 @@ class ModelsTrainingPanel:
         if isinstance(stats, dict):
             stats_summary = stats
 
-        # Model states (for best_loss and checkpoint flags)
-        model_states: Dict[str, Dict[str, Any]] = getattr(self.orchestrator, "model_states", {}) or {}
+        # Model states (for best_loss and checkpoint flags) - use get_model_states() for updated checkpoint info
+        model_states: Dict[str, Dict[str, Any]] = {}
+        if hasattr(self.orchestrator, "get_model_states"):
+            try:
+                model_states = self.orchestrator.get_model_states() or {}
+            except Exception as e:
+                logger.debug(f"Error getting model states: {e}")
+                model_states = getattr(self.orchestrator, "model_states", {}) or {}
 
         # Build models block from stats_summary
         for model_name, s in stats_summary.items():
             model_info: Dict[str, Any] = {}
             try:
+                # Use actual model name - no mapping needed
+                display_name = model_name.upper()
                 # Status: active if we have any inference info
                 total_inf = int(s.get("total_inferences", 0) or 0)
                 status = "active" if (total_inf > 0 or s.get("last_inference_time")) else "registered"
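The panel now prefers the orchestrator's get_model_states() accessor (which carries refreshed checkpoint info) and only falls back to the raw attribute when the call is unavailable or raises. A minimal sketch of that preference order, assuming only the two access paths used above (the helper name is hypothetical):

import logging

logger = logging.getLogger(__name__)

def read_model_states(orchestrator) -> dict:
    """Prefer get_model_states(); fall back to the plain model_states attribute."""
    if hasattr(orchestrator, "get_model_states"):
        try:
            return orchestrator.get_model_states() or {}
        except Exception as e:
            logger.debug(f"Error getting model states: {e}")
    return getattr(orchestrator, "model_states", {}) or {}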
@@ -84,22 +90,35 @@ class ModelsTrainingPanel:
                     "last_training": s.get("last_training_time"),
                 }
 
-                # Checkpoint flags from orchestrator.model_states if present
+                # Checkpoint flags from orchestrator.model_states - map external names to internal keys
                 ckpt_loaded = False
                 ckpt_failed = False
-                if model_states and model_name in model_states:
-                    ckpt_loaded = bool(model_states[model_name].get("checkpoint_loaded", False))
-                    ckpt_failed = bool(model_states[model_name].get("checkpoint_failed", False))
+                checkpoint_filename = None
+
+                # Map external model names to internal model state keys
+                state_key_map = {
+                    "dqn_agent": "dqn",
+                    "enhanced_cnn": "cnn",
+                    "extrema_trainer": "extrema_trainer",
+                    "decision": "decision",
+                    "cob_rl_model": "cob_rl"
+                }
+                state_key = state_key_map.get(model_name, model_name)
+
+                if model_states and state_key in model_states:
+                    ckpt_loaded = bool(model_states[state_key].get("checkpoint_loaded", False))
+                    ckpt_failed = bool(model_states[state_key].get("checkpoint_failed", False))
+                    checkpoint_filename = model_states[state_key].get("checkpoint_filename")
+
                     # If model started fresh, mark as stored after first successful autosave/best_loss update
                     if not ckpt_loaded:
-                        cur = model_states[model_name].get("current_loss")
-                        best = model_states[model_name].get("best_loss")
+                        cur = model_states[state_key].get("current_loss")
+                        best = model_states[state_key].get("best_loss")
                         if isinstance(cur, (int, float)) and isinstance(best, (int, float)) and best is not None and cur is not None and cur <= best:
                             ckpt_loaded = True
-                            checkpoint_filename = model_states[model_name].get("checkpoint_filename")
 
                 model_info = {
-                    "name": model_name,
+                    "name": display_name,
                     "status": status,
                     "parameters": 0,  # unknown; do not synthesize
                     "last_prediction": {
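Because the orchestrator keys model_states by short internal names ("dqn", "cnn", ...) while the training stats use registry names ("dqn_agent", "enhanced_cnn", ...), the panel translates before reading the checkpoint flags. A compact sketch of that lookup, mirroring the state_key_map above (the helper name is hypothetical):

STATE_KEY_MAP = {
    "dqn_agent": "dqn",
    "enhanced_cnn": "cnn",
    "extrema_trainer": "extrema_trainer",
    "decision": "decision",
    "cob_rl_model": "cob_rl",
}

def checkpoint_flags(model_states: dict, model_name: str) -> dict:
    """Read checkpoint_loaded/failed/filename for a registry-level model name."""
    state = model_states.get(STATE_KEY_MAP.get(model_name, model_name), {}) or {}
    return {
        "checkpoint_loaded": bool(state.get("checkpoint_loaded", False)),
        "checkpoint_failed": bool(state.get("checkpoint_failed", False)),
        "checkpoint_filename": state.get("checkpoint_filename"),
    }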
@@ -112,7 +131,7 @@ class ModelsTrainingPanel:
                     "routing_enabled": True,
                     "checkpoint_loaded": ckpt_loaded,
                     "checkpoint_failed": ckpt_failed,
-                    "checkpoint_filename": checkpoint_filename if model_states and model_name in model_states else None,
+                    "checkpoint_filename": checkpoint_filename,
                     "loss_metrics": loss_metrics,
                     "timing_metrics": timing_metrics,
                     "signal_stats": {},
@@ -145,12 +164,6 @@ class ModelsTrainingPanel:
             logger.error(f"Error gathering panel data: {e}")
         return result
 
-    def _create_header(self) -> html.Div:
-        return html.Div([
-            html.H6([html.I(className="fas fa-brain me-2 text-primary"), "Models & Training Progress"], className="mb-2"),
-            html.Button([html.I(className="fas fa-sync-alt me-1"), "Refresh"], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary mb-2")
-        ], className="d-flex justify-content-between align-items-start")
-
     def _create_models_section(self, models: Dict[str, Any]) -> html.Div:
         cards = [self._create_model_card(name, info) for name, info in models.items()]
         return html.Div([