fixes
@@ -315,6 +315,16 @@ class TradingOrchestrator:
         logger.info(f"Using device: {self.device}")
 
+        # Canonical model name aliases to eliminate ambiguity across UI/DB/FS
+        # Canonical → accepted aliases (internal/legacy)
+        self.model_name_aliases: Dict[str, list] = {
+            "DQN": ["dqn_agent", "dqn"],
+            "CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
+            "EXTREMA": ["extrema_trainer", "extrema"],
+            "COB": ["cob_rl_model", "cob_rl"],
+            "DECISION": ["decision_fusion", "decision"],
+        }
+
         # Configuration - AGGRESSIVE for more training data
         self.confidence_threshold = self.config.orchestrator.get(
             "confidence_threshold", 0.15
@@ -555,17 +565,16 @@ class TradingOrchestrator:
     def _normalize_model_name(self, name: str) -> str:
         """Map various registry/UI names to canonical toggle keys."""
         try:
-            mapping = {
-                "dqn_agent": "dqn",
-                "enhanced_cnn": "cnn",
-                "cnn_model": "cnn",
-                "decision": "decision_fusion",
-                "decision_fusion": "decision_fusion",
-                "cob_rl_model": "cob_rl",
-                "cob_rl": "cob_rl",
-                "transformer_model": "transformer",
+            # Use alias map to unify names to canonical keys
+            alias_to_canonical = {
+                **{alias: "DQN" for alias in ["dqn_agent", "dqn"]},
+                **{alias: "CNN" for alias in ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"]},
+                **{alias: "EXTREMA" for alias in ["extrema_trainer", "extrema"]},
+                **{alias: "COB" for alias in ["cob_rl_model", "cob_rl"]},
+                **{alias: "DECISION" for alias in ["decision_fusion", "decision"]},
+                "transformer_model": "TRANSFORMER",
             }
-            return mapping.get(name, name)
+            return alias_to_canonical.get(name, name)
         except Exception:
             return name
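
For reference, a minimal standalone sketch of the alias mapping introduced by the two hunks above, outside the class. The names and canonical keys come from the diff; the module-level constants and the standalone function are illustrative, not part of the change.

from typing import Dict, List

# Canonical -> accepted aliases, mirroring self.model_name_aliases above
MODEL_NAME_ALIASES: Dict[str, List[str]] = {
    "DQN": ["dqn_agent", "dqn"],
    "CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
    "EXTREMA": ["extrema_trainer", "extrema"],
    "COB": ["cob_rl_model", "cob_rl"],
    "DECISION": ["decision_fusion", "decision"],
    "TRANSFORMER": ["transformer_model"],
}

# Invert once: alias -> canonical key
ALIAS_TO_CANONICAL: Dict[str, str] = {
    alias: canonical
    for canonical, aliases in MODEL_NAME_ALIASES.items()
    for alias in aliases
}

def normalize_model_name(name: str) -> str:
    """Return the canonical key for any registry/UI alias; unknown names pass through."""
    return ALIAS_TO_CANONICAL.get(name, name)

assert normalize_model_name("dqn_agent") == "DQN"
assert normalize_model_name("standardized_cnn") == "CNN"
assert normalize_model_name("unknown_model") == "unknown_model"

Deriving the reverse lookup from the canonical map keeps the two structures in sync; the literal alias_to_canonical in the hunk repeats the same alias lists by hand.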
@@ -643,41 +652,44 @@ class TradingOrchestrator:
             )
             self.rl_agent.to(self.device)  # Move DQN agent to the determined device
 
-            # Load best checkpoint and capture initial state (using database metadata)
+            # Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
             checkpoint_loaded = False
             if hasattr(self.rl_agent, "load_best_checkpoint"):
                 try:
-                    self.rl_agent.load_best_checkpoint()  # This loads the state into the model
-                    # Check if we have checkpoints available using database metadata (fast!)
-                    db_manager = get_database_manager()
-                    checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
-                        "dqn_agent"
-                    )
+                    self.rl_agent.load_best_checkpoint()  # Load model state if available
+                    # 1) Try DB metadata first
+                    try:
+                        db_manager = get_database_manager()
+                        checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
+                    except Exception:
+                        checkpoint_metadata = None
                     if checkpoint_metadata:
                         self.model_states["dqn"]["initial_loss"] = 0.412
-                        self.model_states["dqn"]["current_loss"] = (
-                            checkpoint_metadata.performance_metrics.get("loss", 0.0)
-                        )
-                        self.model_states["dqn"]["best_loss"] = (
-                            checkpoint_metadata.performance_metrics.get("loss", 0.0)
-                        )
+                        self.model_states["dqn"]["current_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
+                        self.model_states["dqn"]["best_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
                         self.model_states["dqn"]["checkpoint_loaded"] = True
-                        self.model_states["dqn"][
-                            "checkpoint_filename"
-                        ] = checkpoint_metadata.checkpoint_id
+                        self.model_states["dqn"]["checkpoint_filename"] = checkpoint_metadata.checkpoint_id
                         checkpoint_loaded = True
                         loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
-                        logger.info(
-                            f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
-                        )
+                        logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
+                    else:
+                        # 2) Filesystem fallback via CheckpointManager
+                        try:
+                            from utils.checkpoint_manager import get_checkpoint_manager
+                            cm = get_checkpoint_manager()
+                            result = cm.load_best_checkpoint("dqn_agent")
+                            if result:
+                                model_path, meta = result
+                                # We already loaded model weights via load_best_checkpoint; just record metadata
+                                self.model_states["dqn"]["checkpoint_loaded"] = True
+                                self.model_states["dqn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
+                                checkpoint_loaded = True
+                                logger.info(f"DQN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
+                        except Exception:
+                            pass
                 except Exception as e:
-                    logger.warning(
-                        f"Error loading DQN checkpoint (likely dimension mismatch): {e}"
-                    )
-                    logger.info(
-                        "DQN will start fresh due to checkpoint incompatibility"
-                    )
-                    # Reset the agent to handle dimension mismatch
+                    logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
+                    logger.info("DQN will start fresh due to checkpoint incompatibility")
                     checkpoint_loaded = False
 
             if not checkpoint_loaded:
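
The checkpoint lookup above now tries database metadata first and only then falls back to the filesystem CheckpointManager. A minimal sketch of that resolution order in isolation; resolve_checkpoint, load_metadata_from_db and load_from_filesystem are hypothetical stand-ins for the db_manager.get_best_checkpoint_metadata and checkpoint_manager.load_best_checkpoint calls used in the diff.

from typing import Any, Callable, Optional, Tuple

def resolve_checkpoint(
    model_name: str,
    load_metadata_from_db: Callable[[str], Optional[Any]],
    load_from_filesystem: Callable[[str], Optional[Tuple[str, Any]]],
) -> Tuple[bool, Optional[str]]:
    """Return (checkpoint_loaded, checkpoint_id), preferring DB metadata over a filesystem scan."""
    # 1) Try DB metadata first; treat any failure as "no metadata"
    try:
        meta = load_metadata_from_db(model_name)
    except Exception:
        meta = None
    if meta is not None:
        return True, getattr(meta, "checkpoint_id", None)

    # 2) Filesystem fallback: the manager returns (path, metadata) or None
    try:
        result = load_from_filesystem(model_name)
    except Exception:
        result = None
    if result:
        _path, fs_meta = result
        return True, getattr(fs_meta, "checkpoint_id", None)

    return False, None

In the orchestrator the outcome of this lookup is written into self.model_states[...]["checkpoint_loaded"] and ["checkpoint_filename"], as shown in the hunk above.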
@@ -712,7 +724,7 @@ class TradingOrchestrator:
                 self.cnn_model.parameters(), lr=0.001
             )  # Initialize optimizer for CNN
 
-            # Load best checkpoint and capture initial state (using database metadata)
+            # Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
             checkpoint_loaded = False
             try:
                 db_manager = get_database_manager()
@@ -738,6 +750,19 @@ class TradingOrchestrator:
                 )
             except Exception as e:
                 logger.warning(f"Error loading CNN checkpoint: {e}")
+                # Filesystem fallback
+                try:
+                    from utils.checkpoint_manager import get_checkpoint_manager
+                    cm = get_checkpoint_manager()
+                    result = cm.load_best_checkpoint("enhanced_cnn")
+                    if result:
+                        model_path, meta = result
+                        self.model_states["cnn"]["checkpoint_loaded"] = True
+                        self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
+                        checkpoint_loaded = True
+                        logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
+                except Exception:
+                    pass
 
             if not checkpoint_loaded:
                 # New model - no synthetic data
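
The filesystem-fallback bookkeeping is the same for the DQN and CNN branches. A minimal sketch of that shared step in isolation; record_fs_checkpoint is a hypothetical helper name, and the (path, metadata) tuple shape follows CheckpointManager.load_best_checkpoint as used above.

from typing import Any, Dict, Optional, Tuple

def record_fs_checkpoint(
    model_states: Dict[str, Dict[str, Any]],
    state_key: str,
    result: Optional[Tuple[str, Any]],
) -> bool:
    """Record a filesystem checkpoint hit (path, metadata) under model_states[state_key].

    Mirrors the inline bookkeeping done for both "dqn" and "cnn" above; returns True
    if a checkpoint was found. Illustrative only, not part of the diff.
    """
    if not result:
        return False
    _model_path, meta = result
    model_states[state_key]["checkpoint_loaded"] = True
    model_states[state_key]["checkpoint_filename"] = getattr(meta, "checkpoint_id", None)
    return True

Usage would look like: checkpoint_loaded = record_fs_checkpoint(self.model_states, "cnn", cm.load_best_checkpoint("enhanced_cnn")).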
@@ -7951,28 +7976,29 @@ class TradingOrchestrator:
 
         self.checkpoint_manager = get_checkpoint_manager()
 
-        # Initialize model states dictionary to track performance
-        self.model_states = {
-            "dqn": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "cnn": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "cob_rl": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
-            },
-            "extrema": {
-                "initial_loss": None,
-                "current_loss": None,
-                "best_loss": float("inf"),
-                "checkpoint_loaded": False,
+        # Initialize model states dictionary to track performance (only if not already initialized)
+        if not hasattr(self, 'model_states') or self.model_states is None:
+            self.model_states = {
+                "dqn": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "cnn": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "cob_rl": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
+                },
+                "extrema": {
+                    "initial_loss": None,
+                    "current_loss": None,
+                    "best_loss": float("inf"),
+                    "checkpoint_loaded": False,
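
The hasattr guard above makes this initialization a no-op when model_states already exists, so fields set earlier (such as checkpoint_loaded) are not wiped by a later call. A minimal sketch of the same guard; StateHolder and default_model_state are illustrative names, while the field set matches the diff.

from typing import Any, Dict

def default_model_state() -> Dict[str, Any]:
    """Fresh per-model tracking record with the fields used in the diff."""
    return {
        "initial_loss": None,
        "current_loss": None,
        "best_loss": float("inf"),
        "checkpoint_loaded": False,
    }

class StateHolder:
    """Minimal stand-in showing the hasattr guard: re-running init keeps existing state."""

    def init_model_states(self) -> None:
        if not hasattr(self, "model_states") or self.model_states is None:
            self.model_states = {
                name: default_model_state()
                for name in ("dqn", "cnn", "cob_rl", "extrema")
            }

holder = StateHolder()
holder.init_model_states()
holder.model_states["dqn"]["checkpoint_loaded"] = True
holder.init_model_states()  # second call is a no-op; the flag survives
assert holder.model_states["dqn"]["checkpoint_loaded"] is True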