This commit is contained in:
Dobromir Popov
2025-08-10 21:28:49 +03:00
parent 7289366a35
commit 8738f02d24
5 changed files with 311 additions and 78 deletions

View File

@@ -315,6 +315,16 @@ class TradingOrchestrator:
logger.info(f"Using device: {self.device}")
# Canonical model name aliases to eliminate ambiguity across UI/DB/FS
# Canonical → accepted aliases (internal/legacy)
self.model_name_aliases: Dict[str, list] = {
"DQN": ["dqn_agent", "dqn"],
"CNN": ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"],
"EXTREMA": ["extrema_trainer", "extrema"],
"COB": ["cob_rl_model", "cob_rl"],
"DECISION": ["decision_fusion", "decision"],
}
# Configuration - AGGRESSIVE for more training data
self.confidence_threshold = self.config.orchestrator.get(
"confidence_threshold", 0.15
@@ -555,17 +565,16 @@ class TradingOrchestrator:
def _normalize_model_name(self, name: str) -> str:
"""Map various registry/UI names to canonical toggle keys."""
try:
mapping = {
"dqn_agent": "dqn",
"enhanced_cnn": "cnn",
"cnn_model": "cnn",
"decision": "decision_fusion",
"decision_fusion": "decision_fusion",
"cob_rl_model": "cob_rl",
"cob_rl": "cob_rl",
"transformer_model": "transformer",
# Use alias map to unify names to canonical keys
alias_to_canonical = {
**{alias: "DQN" for alias in ["dqn_agent", "dqn"]},
**{alias: "CNN" for alias in ["enhanced_cnn", "cnn", "cnn_model", "standardized_cnn"]},
**{alias: "EXTREMA" for alias in ["extrema_trainer", "extrema"]},
**{alias: "COB" for alias in ["cob_rl_model", "cob_rl"]},
**{alias: "DECISION" for alias in ["decision_fusion", "decision"]},
"transformer_model": "TRANSFORMER",
}
return mapping.get(name, name)
return alias_to_canonical.get(name, name)
except Exception:
return name
@@ -643,41 +652,44 @@ class TradingOrchestrator:
)
self.rl_agent.to(self.device) # Move DQN agent to the determined device
# Load best checkpoint and capture initial state (using database metadata)
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
checkpoint_loaded = False
if hasattr(self.rl_agent, "load_best_checkpoint"):
try:
self.rl_agent.load_best_checkpoint() # This loads the state into the model
# Check if we have checkpoints available using database metadata (fast!)
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
"dqn_agent"
)
self.rl_agent.load_best_checkpoint() # Load model state if available
# 1) Try DB metadata first
try:
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
except Exception:
checkpoint_metadata = None
if checkpoint_metadata:
self.model_states["dqn"]["initial_loss"] = 0.412
self.model_states["dqn"]["current_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["dqn"]["best_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["dqn"]["current_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
self.model_states["dqn"]["best_loss"] = checkpoint_metadata.performance_metrics.get("loss", 0.0)
self.model_states["dqn"]["checkpoint_loaded"] = True
self.model_states["dqn"][
"checkpoint_filename"
] = checkpoint_metadata.checkpoint_id
self.model_states["dqn"]["checkpoint_filename"] = checkpoint_metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(
f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
)
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
else:
# 2) Filesystem fallback via CheckpointManager
try:
from utils.checkpoint_manager import get_checkpoint_manager
cm = get_checkpoint_manager()
result = cm.load_best_checkpoint("dqn_agent")
if result:
model_path, meta = result
# We already loaded model weights via load_best_checkpoint; just record metadata
self.model_states["dqn"]["checkpoint_loaded"] = True
self.model_states["dqn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
checkpoint_loaded = True
logger.info(f"DQN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
except Exception:
pass
except Exception as e:
logger.warning(
f"Error loading DQN checkpoint (likely dimension mismatch): {e}"
)
logger.info(
"DQN will start fresh due to checkpoint incompatibility"
)
# Reset the agent to handle dimension mismatch
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
logger.info("DQN will start fresh due to checkpoint incompatibility")
checkpoint_loaded = False
if not checkpoint_loaded:
@@ -712,7 +724,7 @@ class TradingOrchestrator:
self.cnn_model.parameters(), lr=0.001
) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state (using database metadata)
# Load best checkpoint and capture initial state (using database metadata or filesystem fallback)
checkpoint_loaded = False
try:
db_manager = get_database_manager()
@@ -738,6 +750,19 @@ class TradingOrchestrator:
)
except Exception as e:
logger.warning(f"Error loading CNN checkpoint: {e}")
# Filesystem fallback
try:
from utils.checkpoint_manager import get_checkpoint_manager
cm = get_checkpoint_manager()
result = cm.load_best_checkpoint("enhanced_cnn")
if result:
model_path, meta = result
self.model_states["cnn"]["checkpoint_loaded"] = True
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
checkpoint_loaded = True
logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
except Exception:
pass
if not checkpoint_loaded:
# New model - no synthetic data
@@ -7951,28 +7976,29 @@ class TradingOrchestrator:
self.checkpoint_manager = get_checkpoint_manager()
# Initialize model states dictionary to track performance
self.model_states = {
"dqn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cnn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cob_rl": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"extrema": {
"initial_loss": None,
# Initialize model states dictionary to track performance (only if not already initialized)
if not hasattr(self, 'model_states') or self.model_states is None:
self.model_states = {
"dqn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cnn": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"cob_rl": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,
},
"extrema": {
"initial_loss": None,
"current_loss": None,
"best_loss": float("inf"),
"checkpoint_loaded": False,