pivot points option in UI
This commit is contained in:
@@ -584,7 +584,6 @@ class TradingOrchestrator:
|
||||
return alias_to_canonical.get(name, name)
|
||||
except Exception:
|
||||
return name
|
||||
|
||||
def _initialize_ml_models(self):
|
||||
"""Initialize ML models for enhanced trading"""
|
||||
try:
|
||||
@@ -738,45 +737,42 @@ class TradingOrchestrator:
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
|
||||
"enhanced_cnn"
|
||||
)
|
||||
if checkpoint_metadata:
|
||||
self.model_states["cnn"]["initial_loss"] = 0.412
|
||||
self.model_states["cnn"]["current_loss"] = (
|
||||
checkpoint_metadata.performance_metrics.get("loss", 0.0187)
|
||||
)
|
||||
self.model_states["cnn"]["best_loss"] = (
|
||||
checkpoint_metadata.performance_metrics.get("loss", 0.0134)
|
||||
)
|
||||
self.model_states["cnn"]["checkpoint_loaded"] = True
|
||||
self.model_states["cnn"][
|
||||
"checkpoint_filename"
|
||||
] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(
|
||||
f"CNN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
|
||||
)
|
||||
if checkpoint_metadata and os.path.exists(checkpoint_metadata.file_path):
|
||||
try:
|
||||
saved = torch.load(checkpoint_metadata.file_path, map_location=self.device)
|
||||
if saved and saved.get("model_state_dict"):
|
||||
self.cnn_model.load_state_dict(saved["model_state_dict"], strict=False)
|
||||
checkpoint_loaded = True
|
||||
except Exception as load_ex:
|
||||
logger.warning(f"CNN checkpoint load_state_dict failed: {load_ex}")
|
||||
if not checkpoint_loaded:
|
||||
# Filesystem fallback
|
||||
from utils.checkpoint_manager import load_best_checkpoint as _load_best_ckpt
|
||||
result = _load_best_ckpt("enhanced_cnn")
|
||||
if result:
|
||||
ckpt_path, meta = result
|
||||
try:
|
||||
saved = torch.load(ckpt_path, map_location=self.device)
|
||||
if saved and saved.get("model_state_dict"):
|
||||
self.cnn_model.load_state_dict(saved["model_state_dict"], strict=False)
|
||||
checkpoint_loaded = True
|
||||
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, "checkpoint_id", os.path.basename(ckpt_path))
|
||||
except Exception as e_load:
|
||||
logger.warning(f"Failed loading CNN weights from {ckpt_path}: {e_load}")
|
||||
# Update model_states flags after attempts
|
||||
self.model_states["cnn"]["checkpoint_loaded"] = checkpoint_loaded
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
# Filesystem fallback
|
||||
try:
|
||||
from utils.checkpoint_manager import get_checkpoint_manager
|
||||
cm = get_checkpoint_manager()
|
||||
result = cm.load_best_checkpoint("enhanced_cnn")
|
||||
if result:
|
||||
model_path, meta = result
|
||||
self.model_states["cnn"]["checkpoint_loaded"] = True
|
||||
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
|
||||
checkpoint_loaded = True
|
||||
logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
checkpoint_loaded = False
|
||||
if not checkpoint_loaded:
|
||||
# New model - no synthetic data
|
||||
self.model_states["cnn"]["initial_loss"] = None
|
||||
self.model_states["cnn"]["current_loss"] = None
|
||||
self.model_states["cnn"]["best_loss"] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
self.model_states["cnn"]["checkpoint_loaded"] = False
|
||||
logger.info("CNN starting fresh - no checkpoint found or failed to load")
|
||||
else:
|
||||
logger.info("CNN weights loaded from checkpoint successfully")
|
||||
|
||||
logger.info("Enhanced CNN model initialized directly")
|
||||
except ImportError:
|
||||
@@ -1339,7 +1335,6 @@ class TradingOrchestrator:
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def clear_session_data(self):
|
||||
"""Clear all session-related data for fresh start"""
|
||||
try:
|
||||
@@ -2122,7 +2117,6 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering model {model.name}: {e}")
|
||||
return False
|
||||
|
||||
def unregister_model(self, model_name: str) -> bool:
|
||||
"""Unregister a model"""
|
||||
try:
|
||||
@@ -3540,7 +3534,6 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in immediate training for {model_name}: {e}")
|
||||
|
||||
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
||||
"""Evaluate prediction outcome and train model"""
|
||||
try:
|
||||
@@ -5779,7 +5772,6 @@ class TradingOrchestrator:
|
||||
if symbol in self.recent_decisions:
|
||||
return self.recent_decisions[symbol][-limit:]
|
||||
return []
|
||||
|
||||
def get_performance_metrics(self) -> Dict[str, Any]:
|
||||
"""Get performance metrics for the orchestrator"""
|
||||
return {
|
||||
@@ -6579,7 +6571,6 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding decision fusion training sample: {e}")
|
||||
|
||||
def _train_decision_fusion_network(self):
|
||||
"""Train the decision fusion network on collected data"""
|
||||
try:
|
||||
@@ -8133,7 +8124,6 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||
self.checkpoint_manager = None
|
||||
|
||||
def autosave_models(self):
|
||||
"""Attempt to autosave best model checkpoints periodically."""
|
||||
try:
|
||||
@@ -8990,4 +8980,4 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback data strategy: {e}")
|
||||
return False
|
||||
return False
|
Reference in New Issue
Block a user