fix models loading /saving issue

This commit is contained in:
Dobromir Popov
2025-09-02 16:05:44 +03:00
parent 1b54438082
commit 15cc694669
13 changed files with 2264 additions and 72 deletions

View File

@@ -199,12 +199,13 @@ class TradingOrchestrator:
logger.info("Initializing ML models...")
# Initialize model state tracking (SSOT)
# Note: COB_RL functionality is now integrated into Enhanced CNN
self.model_states = {
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
'transformer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
}
# Initialize DQN Agent
@@ -282,7 +283,9 @@ class TradingOrchestrator:
self.model_states['cnn']['best_loss'] = None
logger.info("CNN starting fresh - no checkpoint found")
logger.info("Enhanced CNN model initialized")
logger.info("Enhanced CNN model initialized with integrated COB functionality")
logger.info(" - CNN handles both price patterns AND market microstructure (COB) analysis")
logger.info(" - Unified model eliminates redundancy and improves context integration")
except ImportError:
try:
from NN.models.cnn_model import CNNModel
@@ -338,48 +341,102 @@ class TradingOrchestrator:
logger.warning("Extrema trainer not available")
self.extrema_trainer = None
# Initialize COB RL Model
try:
from NN.models.cob_rl_model import COBRLModelInterface
self.cob_rl_agent = COBRLModelInterface()
# Load best checkpoint and capture initial state
checkpoint_loaded = False
if hasattr(self.cob_rl_agent, 'load_model'):
try:
self.cob_rl_agent.load_model() # This loads the state into the model
from utils.checkpoint_manager import load_best_checkpoint
# Use consistent model name with checkpoint manager and get_model_states
result = load_best_checkpoint("cob_rl")
if result:
file_path, metadata = result
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
self.model_states['cob_rl']['current_loss'] = metadata.loss
self.model_states['cob_rl']['best_loss'] = metadata.loss
self.model_states['cob_rl']['checkpoint_loaded'] = True
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
except Exception as e:
logger.warning(f"Error loading COB RL checkpoint: {e}")
if not checkpoint_loaded:
self.model_states['cob_rl']['initial_loss'] = None
self.model_states['cob_rl']['current_loss'] = None
self.model_states['cob_rl']['best_loss'] = None
self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
logger.info("COB RL starting fresh - no checkpoint found")
logger.info("COB RL model initialized")
except ImportError:
logger.warning("COB RL model not available")
self.cob_rl_agent = None
# COB RL functionality is now integrated into the Enhanced CNN model
# The Enhanced CNN already receives COB data and has microstructure attention
# This eliminates redundancy and improves context integration
logger.info("COB RL functionality integrated into Enhanced CNN - no separate model needed")
self.cob_rl_agent = None # Deprecated in favor of Enhanced CNN integration
# Initialize Decision model state - no synthetic data
self.model_states['decision']['initial_loss'] = None
self.model_states['decision']['current_loss'] = None
self.model_states['decision']['best_loss'] = None
# Initialize TRANSFORMER Model
try:
from NN.models.advanced_transformer_trading import create_trading_transformer, TradingTransformerConfig
config = TradingTransformerConfig(
d_model=256, # 15M parameters target
n_heads=8,
n_layers=4,
seq_len=50,
n_actions=3,
use_multi_scale_attention=True,
use_market_regime_detection=True,
use_uncertainty_estimation=True
)
self.transformer_model, self.transformer_trainer = create_trading_transformer(config)
# Load best checkpoint
checkpoint_loaded = False
try:
from utils.checkpoint_manager import load_best_checkpoint
result = load_best_checkpoint("transformer")
if result:
file_path, metadata = result
self.transformer_trainer.load_model(file_path)
self.model_states['transformer']['checkpoint_loaded'] = True
self.model_states['transformer']['checkpoint_filename'] = metadata.checkpoint_id
checkpoint_loaded = True
logger.info(f"Transformer checkpoint loaded: {metadata.checkpoint_id}")
except Exception as e:
logger.debug(f"No transformer checkpoint found: {e}")
if not checkpoint_loaded:
self.model_states['transformer']['checkpoint_loaded'] = False
self.model_states['transformer']['checkpoint_filename'] = 'none (fresh start)'
logger.info("Transformer starting fresh - no checkpoint found")
logger.info("Transformer model initialized")
except ImportError as e:
logger.warning(f"Transformer model not available: {e}")
self.transformer_model = None
self.transformer_trainer = None
# Initialize Decision Fusion Model
try:
from core.nn_decision_fusion import NeuralDecisionFusion
# Initialize decision fusion (training_mode parameter only)
self.decision_model = NeuralDecisionFusion(training_mode=True)
# Load best checkpoint
checkpoint_loaded = False
try:
from utils.checkpoint_manager import load_best_checkpoint
result = load_best_checkpoint("decision")
if result:
file_path, metadata = result
import torch
checkpoint = torch.load(file_path, map_location='cpu')
if 'model_state_dict' in checkpoint:
self.decision_model.load_state_dict(checkpoint['model_state_dict'])
self.model_states['decision']['checkpoint_loaded'] = True
self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
checkpoint_loaded = True
logger.info(f"Decision model checkpoint loaded: {metadata.checkpoint_id}")
except Exception as e:
logger.debug(f"No decision model checkpoint found: {e}")
if not checkpoint_loaded:
self.model_states['decision']['checkpoint_loaded'] = False
self.model_states['decision']['checkpoint_filename'] = 'none (fresh start)'
logger.info("Decision model starting fresh - no checkpoint found")
logger.info("Decision fusion model initialized")
except ImportError as e:
logger.warning(f"Decision fusion model not available: {e}")
self.decision_model = None
# Initialize all model states with defaults for non-loaded models
for model_name in ['decision', 'transformer']:
if model_name not in self.model_states:
self.model_states[model_name] = {
'initial_loss': None,
'current_loss': None,
'best_loss': None,
'checkpoint_loaded': False,
'checkpoint_filename': 'none (fresh start)'
}
# CRITICAL: Register models with the model registry
logger.info("Registering models with model registry...")
@@ -431,20 +488,59 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Failed to register Extrema Trainer: {e}")
# Register COB RL Agent
if self.cob_rl_agent:
try:
cob_rl_interface = COBRLModelInterface(self.cob_rl_agent, name="cob_rl_model")
self.register_model(cob_rl_interface, weight=0.15)
logger.info("COB RL Agent registered successfully")
except Exception as e:
logger.error(f"Failed to register COB RL Agent: {e}")
# COB RL functionality is now integrated into Enhanced CNN
# No separate registration needed - COB analysis is part of CNN microstructure attention
logger.info("COB RL functionality integrated into Enhanced CNN - no separate registration needed")
# If decision model is initialized elsewhere, ensure it's registered too
# Register Transformer Model
if hasattr(self, 'transformer_model') and self.transformer_model:
try:
class TransformerModelInterface(ModelInterface):
def __init__(self, model, trainer, name: str):
super().__init__(name)
self.model = model
self.trainer = trainer
def predict(self, data):
try:
if hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in transformer prediction: {e}")
return None
def get_memory_usage(self) -> float:
return 60.0 # MB estimate for transformer
transformer_interface = TransformerModelInterface(self.transformer_model, self.transformer_trainer, name="transformer")
self.register_model(transformer_interface, weight=0.2)
logger.info("Transformer Model registered successfully")
except Exception as e:
logger.error(f"Failed to register Transformer Model: {e}")
# Register Decision Fusion Model
if hasattr(self, 'decision_model') and self.decision_model:
try:
decision_interface = ModelInterface(self.decision_model, name="decision_fusion")
self.register_model(decision_interface, weight=0.2) # Weight for decision fusion
class DecisionModelInterface(ModelInterface):
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data):
try:
if hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in decision model prediction: {e}")
return None
def get_memory_usage(self) -> float:
return 40.0 # MB estimate for decision model
decision_interface = DecisionModelInterface(self.decision_model, name="decision")
self.register_model(decision_interface, weight=0.15)
logger.info("Decision Fusion Model registered successfully")
except Exception as e:
logger.error(f"Failed to register Decision Fusion Model: {e}")
@@ -452,6 +548,7 @@ class TradingOrchestrator:
# Normalize weights after all registrations
self._normalize_weights()
logger.info(f"Current model weights: {self.model_weights}")
logger.info("COB_RL consolidation completed - Enhanced CNN now handles microstructure analysis")
except Exception as e:
logger.error(f"Error initializing ML models: {e}")
@@ -479,6 +576,45 @@ class TradingOrchestrator:
self.model_states[model_name]['best_loss'] = saved_loss
logger.info(f"New best loss for {model_name}: {saved_loss:.4f}")
def get_recent_predictions(self, limit: int = 10) -> List[Dict[str, Any]]:
"""Get recent predictions from all models for data streaming"""
try:
predictions = []
# Collect predictions from prediction history if available
if hasattr(self, 'prediction_history'):
for symbol, preds in self.prediction_history.items():
recent_preds = list(preds)[-limit:]
for pred in recent_preds:
predictions.append({
'timestamp': pred.get('timestamp', datetime.now().isoformat()),
'model_name': pred.get('model_name', 'unknown'),
'symbol': symbol,
'prediction': pred.get('prediction'),
'confidence': pred.get('confidence', 0),
'action': pred.get('action')
})
# Also collect from current model states
for model_name, state in self.model_states.items():
if 'last_prediction' in state:
predictions.append({
'timestamp': datetime.now().isoformat(),
'model_name': model_name,
'symbol': 'ETH/USDT', # Default symbol
'prediction': state['last_prediction'],
'confidence': state.get('last_confidence', 0),
'action': state.get('last_action')
})
# Sort by timestamp and return most recent
predictions.sort(key=lambda x: x['timestamp'], reverse=True)
return predictions[:limit]
except Exception as e:
logger.debug(f"Error getting recent predictions: {e}")
return []
def _save_orchestrator_state(self):
"""Save the current state of the orchestrator, including model states."""
state = {
@@ -1450,13 +1586,34 @@ class TradingOrchestrator:
def get_model_states(self) -> Dict[str, Dict]:
"""Get current model states with REAL checkpoint data - SSOT for dashboard"""
try:
# ENHANCED: Load actual checkpoint metadata for each model
# Cache checkpoint data to avoid repeated loading
if not hasattr(self, '_checkpoint_cache'):
self._checkpoint_cache = {}
self._checkpoint_cache_time = {}
# Only refresh checkpoint data every 60 seconds to avoid spam
import time
current_time = time.time()
cache_expiry = 60 # seconds
from utils.checkpoint_manager import load_best_checkpoint
# Update each model with REAL checkpoint data
for model_name in ['dqn_agent', 'enhanced_cnn', 'extrema_trainer', 'decision', 'cob_rl']:
# Update each model with REAL checkpoint data (cached)
# Note: COB_RL removed - functionality integrated into Enhanced CNN
for model_name in ['dqn_agent', 'enhanced_cnn', 'extrema_trainer', 'decision', 'transformer']:
try:
result = load_best_checkpoint(model_name)
# Check if we need to refresh cache for this model
needs_refresh = (
model_name not in self._checkpoint_cache or
current_time - self._checkpoint_cache_time.get(model_name, 0) > cache_expiry
)
if needs_refresh:
result = load_best_checkpoint(model_name)
self._checkpoint_cache[model_name] = result
self._checkpoint_cache_time[model_name] = current_time
result = self._checkpoint_cache[model_name]
if result:
file_path, metadata = result
@@ -1466,7 +1623,7 @@ class TradingOrchestrator:
'enhanced_cnn': 'cnn',
'extrema_trainer': 'extrema_trainer',
'decision': 'decision',
'cob_rl': 'cob_rl'
'transformer': 'transformer'
}.get(model_name, model_name)
if internal_key in self.model_states: