fix models loading /saving issue
This commit is contained in:
@@ -199,12 +199,13 @@ class TradingOrchestrator:
|
||||
logger.info("Initializing ML models...")
|
||||
|
||||
# Initialize model state tracking (SSOT)
|
||||
# Note: COB_RL functionality is now integrated into Enhanced CNN
|
||||
self.model_states = {
|
||||
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
|
||||
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'transformer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
# Initialize DQN Agent
|
||||
@@ -282,7 +283,9 @@ class TradingOrchestrator:
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
logger.info("Enhanced CNN model initialized with integrated COB functionality")
|
||||
logger.info(" - CNN handles both price patterns AND market microstructure (COB) analysis")
|
||||
logger.info(" - Unified model eliminates redundancy and improves context integration")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
@@ -338,48 +341,102 @@ class TradingOrchestrator:
|
||||
logger.warning("Extrema trainer not available")
|
||||
self.extrema_trainer = None
|
||||
|
||||
# Initialize COB RL Model
|
||||
try:
|
||||
from NN.models.cob_rl_model import COBRLModelInterface
|
||||
self.cob_rl_agent = COBRLModelInterface()
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
checkpoint_loaded = False
|
||||
if hasattr(self.cob_rl_agent, 'load_model'):
|
||||
try:
|
||||
self.cob_rl_agent.load_model() # This loads the state into the model
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
# Use consistent model name with checkpoint manager and get_model_states
|
||||
result = load_best_checkpoint("cob_rl")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['best_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
self.model_states['cob_rl']['initial_loss'] = None
|
||||
self.model_states['cob_rl']['current_loss'] = None
|
||||
self.model_states['cob_rl']['best_loss'] = None
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("COB RL starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("COB RL model initialized")
|
||||
except ImportError:
|
||||
logger.warning("COB RL model not available")
|
||||
self.cob_rl_agent = None
|
||||
# COB RL functionality is now integrated into the Enhanced CNN model
|
||||
# The Enhanced CNN already receives COB data and has microstructure attention
|
||||
# This eliminates redundancy and improves context integration
|
||||
logger.info("COB RL functionality integrated into Enhanced CNN - no separate model needed")
|
||||
self.cob_rl_agent = None # Deprecated in favor of Enhanced CNN integration
|
||||
|
||||
# Initialize Decision model state - no synthetic data
|
||||
self.model_states['decision']['initial_loss'] = None
|
||||
self.model_states['decision']['current_loss'] = None
|
||||
self.model_states['decision']['best_loss'] = None
|
||||
# Initialize TRANSFORMER Model
|
||||
try:
|
||||
from NN.models.advanced_transformer_trading import create_trading_transformer, TradingTransformerConfig
|
||||
|
||||
config = TradingTransformerConfig(
|
||||
d_model=256, # 15M parameters target
|
||||
n_heads=8,
|
||||
n_layers=4,
|
||||
seq_len=50,
|
||||
n_actions=3,
|
||||
use_multi_scale_attention=True,
|
||||
use_market_regime_detection=True,
|
||||
use_uncertainty_estimation=True
|
||||
)
|
||||
|
||||
self.transformer_model, self.transformer_trainer = create_trading_transformer(config)
|
||||
|
||||
# Load best checkpoint
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("transformer")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.transformer_trainer.load_model(file_path)
|
||||
self.model_states['transformer']['checkpoint_loaded'] = True
|
||||
self.model_states['transformer']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
logger.info(f"Transformer checkpoint loaded: {metadata.checkpoint_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"No transformer checkpoint found: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
self.model_states['transformer']['checkpoint_loaded'] = False
|
||||
self.model_states['transformer']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("Transformer starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Transformer model initialized")
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Transformer model not available: {e}")
|
||||
self.transformer_model = None
|
||||
self.transformer_trainer = None
|
||||
|
||||
# Initialize Decision Fusion Model
|
||||
try:
|
||||
from core.nn_decision_fusion import NeuralDecisionFusion
|
||||
|
||||
# Initialize decision fusion (training_mode parameter only)
|
||||
self.decision_model = NeuralDecisionFusion(training_mode=True)
|
||||
|
||||
# Load best checkpoint
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("decision")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
import torch
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
if 'model_state_dict' in checkpoint:
|
||||
self.decision_model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model_states['decision']['checkpoint_loaded'] = True
|
||||
self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
logger.info(f"Decision model checkpoint loaded: {metadata.checkpoint_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"No decision model checkpoint found: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
self.model_states['decision']['checkpoint_loaded'] = False
|
||||
self.model_states['decision']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("Decision model starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Decision fusion model initialized")
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Decision fusion model not available: {e}")
|
||||
self.decision_model = None
|
||||
|
||||
# Initialize all model states with defaults for non-loaded models
|
||||
for model_name in ['decision', 'transformer']:
|
||||
if model_name not in self.model_states:
|
||||
self.model_states[model_name] = {
|
||||
'initial_loss': None,
|
||||
'current_loss': None,
|
||||
'best_loss': None,
|
||||
'checkpoint_loaded': False,
|
||||
'checkpoint_filename': 'none (fresh start)'
|
||||
}
|
||||
|
||||
# CRITICAL: Register models with the model registry
|
||||
logger.info("Registering models with model registry...")
|
||||
@@ -431,20 +488,59 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Extrema Trainer: {e}")
|
||||
|
||||
# Register COB RL Agent
|
||||
if self.cob_rl_agent:
|
||||
try:
|
||||
cob_rl_interface = COBRLModelInterface(self.cob_rl_agent, name="cob_rl_model")
|
||||
self.register_model(cob_rl_interface, weight=0.15)
|
||||
logger.info("COB RL Agent registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register COB RL Agent: {e}")
|
||||
# COB RL functionality is now integrated into Enhanced CNN
|
||||
# No separate registration needed - COB analysis is part of CNN microstructure attention
|
||||
logger.info("COB RL functionality integrated into Enhanced CNN - no separate registration needed")
|
||||
|
||||
# If decision model is initialized elsewhere, ensure it's registered too
|
||||
# Register Transformer Model
|
||||
if hasattr(self, 'transformer_model') and self.transformer_model:
|
||||
try:
|
||||
class TransformerModelInterface(ModelInterface):
|
||||
def __init__(self, model, trainer, name: str):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
self.trainer = trainer
|
||||
|
||||
def predict(self, data):
|
||||
try:
|
||||
if hasattr(self.model, 'predict'):
|
||||
return self.model.predict(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in transformer prediction: {e}")
|
||||
return None
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
return 60.0 # MB estimate for transformer
|
||||
|
||||
transformer_interface = TransformerModelInterface(self.transformer_model, self.transformer_trainer, name="transformer")
|
||||
self.register_model(transformer_interface, weight=0.2)
|
||||
logger.info("Transformer Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Transformer Model: {e}")
|
||||
|
||||
# Register Decision Fusion Model
|
||||
if hasattr(self, 'decision_model') and self.decision_model:
|
||||
try:
|
||||
decision_interface = ModelInterface(self.decision_model, name="decision_fusion")
|
||||
self.register_model(decision_interface, weight=0.2) # Weight for decision fusion
|
||||
class DecisionModelInterface(ModelInterface):
|
||||
def __init__(self, model, name: str):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
|
||||
def predict(self, data):
|
||||
try:
|
||||
if hasattr(self.model, 'predict'):
|
||||
return self.model.predict(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in decision model prediction: {e}")
|
||||
return None
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
return 40.0 # MB estimate for decision model
|
||||
|
||||
decision_interface = DecisionModelInterface(self.decision_model, name="decision")
|
||||
self.register_model(decision_interface, weight=0.15)
|
||||
logger.info("Decision Fusion Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Decision Fusion Model: {e}")
|
||||
@@ -452,6 +548,7 @@ class TradingOrchestrator:
|
||||
# Normalize weights after all registrations
|
||||
self._normalize_weights()
|
||||
logger.info(f"Current model weights: {self.model_weights}")
|
||||
logger.info("COB_RL consolidation completed - Enhanced CNN now handles microstructure analysis")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
@@ -479,6 +576,45 @@ class TradingOrchestrator:
|
||||
self.model_states[model_name]['best_loss'] = saved_loss
|
||||
logger.info(f"New best loss for {model_name}: {saved_loss:.4f}")
|
||||
|
||||
def get_recent_predictions(self, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Get recent predictions from all models for data streaming"""
|
||||
try:
|
||||
predictions = []
|
||||
|
||||
# Collect predictions from prediction history if available
|
||||
if hasattr(self, 'prediction_history'):
|
||||
for symbol, preds in self.prediction_history.items():
|
||||
recent_preds = list(preds)[-limit:]
|
||||
for pred in recent_preds:
|
||||
predictions.append({
|
||||
'timestamp': pred.get('timestamp', datetime.now().isoformat()),
|
||||
'model_name': pred.get('model_name', 'unknown'),
|
||||
'symbol': symbol,
|
||||
'prediction': pred.get('prediction'),
|
||||
'confidence': pred.get('confidence', 0),
|
||||
'action': pred.get('action')
|
||||
})
|
||||
|
||||
# Also collect from current model states
|
||||
for model_name, state in self.model_states.items():
|
||||
if 'last_prediction' in state:
|
||||
predictions.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'model_name': model_name,
|
||||
'symbol': 'ETH/USDT', # Default symbol
|
||||
'prediction': state['last_prediction'],
|
||||
'confidence': state.get('last_confidence', 0),
|
||||
'action': state.get('last_action')
|
||||
})
|
||||
|
||||
# Sort by timestamp and return most recent
|
||||
predictions.sort(key=lambda x: x['timestamp'], reverse=True)
|
||||
return predictions[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting recent predictions: {e}")
|
||||
return []
|
||||
|
||||
def _save_orchestrator_state(self):
|
||||
"""Save the current state of the orchestrator, including model states."""
|
||||
state = {
|
||||
@@ -1450,13 +1586,34 @@ class TradingOrchestrator:
|
||||
def get_model_states(self) -> Dict[str, Dict]:
|
||||
"""Get current model states with REAL checkpoint data - SSOT for dashboard"""
|
||||
try:
|
||||
# ENHANCED: Load actual checkpoint metadata for each model
|
||||
# Cache checkpoint data to avoid repeated loading
|
||||
if not hasattr(self, '_checkpoint_cache'):
|
||||
self._checkpoint_cache = {}
|
||||
self._checkpoint_cache_time = {}
|
||||
|
||||
# Only refresh checkpoint data every 60 seconds to avoid spam
|
||||
import time
|
||||
current_time = time.time()
|
||||
cache_expiry = 60 # seconds
|
||||
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
# Update each model with REAL checkpoint data
|
||||
for model_name in ['dqn_agent', 'enhanced_cnn', 'extrema_trainer', 'decision', 'cob_rl']:
|
||||
# Update each model with REAL checkpoint data (cached)
|
||||
# Note: COB_RL removed - functionality integrated into Enhanced CNN
|
||||
for model_name in ['dqn_agent', 'enhanced_cnn', 'extrema_trainer', 'decision', 'transformer']:
|
||||
try:
|
||||
result = load_best_checkpoint(model_name)
|
||||
# Check if we need to refresh cache for this model
|
||||
needs_refresh = (
|
||||
model_name not in self._checkpoint_cache or
|
||||
current_time - self._checkpoint_cache_time.get(model_name, 0) > cache_expiry
|
||||
)
|
||||
|
||||
if needs_refresh:
|
||||
result = load_best_checkpoint(model_name)
|
||||
self._checkpoint_cache[model_name] = result
|
||||
self._checkpoint_cache_time[model_name] = current_time
|
||||
|
||||
result = self._checkpoint_cache[model_name]
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
|
||||
@@ -1466,7 +1623,7 @@ class TradingOrchestrator:
|
||||
'enhanced_cnn': 'cnn',
|
||||
'extrema_trainer': 'extrema_trainer',
|
||||
'decision': 'decision',
|
||||
'cob_rl': 'cob_rl'
|
||||
'transformer': 'transformer'
|
||||
}.get(model_name, model_name)
|
||||
|
||||
if internal_key in self.model_states:
|
||||
|
Reference in New Issue
Block a user