Model auto-load
@@ -589,8 +589,14 @@ class AnnotationDashboard:
         # Backtest runner for replaying visible chart with predictions
         self.backtest_runner = BacktestRunner()
 
-        # Don't auto-load models - wait for user to click LOAD button
-        logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
+        # Check if we should auto-load a model at startup
+        auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer')  # Default: Transformer
+
+        if auto_load_model and auto_load_model.lower() != 'none':
+            logger.info(f"Auto-loading model: {auto_load_model}")
+            self._auto_load_model(auto_load_model)
+        else:
+            logger.info("Auto-load disabled. Models available for lazy loading: " + ", ".join(self.available_models))
 
         # Initialize data loader with existing DataProvider
         self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
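The hunk above gates auto-loading on the AUTO_LOAD_MODEL environment variable, defaulting to 'Transformer'. A minimal standalone sketch of how the variable is interpreted (it mirrors the logic in the diff; the printed messages are illustrative only, not part of this commit):

    import os

    # 'none' (any casing) disables auto-loading; unset falls back to 'Transformer';
    # any other value ('DQN', 'CNN', 'Transformer') names the model to auto-load.
    os.environ.setdefault('AUTO_LOAD_MODEL', 'none')

    auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer')
    if auto_load_model and auto_load_model.lower() != 'none':
        print(f"Would auto-load: {auto_load_model}")
    else:
        print("Auto-load disabled; models remain available for lazy loading")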
@@ -605,6 +611,73 @@ class AnnotationDashboard:
 
         logger.info("Annotation Dashboard initialized")
 
+    def _auto_load_model(self, model_name: str):
+        """
+        Auto-load a model at startup in background thread
+
+        Args:
+            model_name: Name of model to load (DQN, CNN, or Transformer)
+        """
+        def load_in_background():
+            try:
+                logger.info(f"Starting auto-load for {model_name}...")
+
+                # Initialize orchestrator if not already done
+                if not self.orchestrator:
+                    logger.info("Initializing TradingOrchestrator...")
+                    self.orchestrator = TradingOrchestrator(
+                        data_provider=self.data_provider,
+                        config=self.config
+                    )
+                    self.training_adapter.orchestrator = self.orchestrator
+                    logger.info("TradingOrchestrator initialized")
+
+                # Load the specific model
+                if model_name == 'Transformer':
+                    logger.info("Loading Transformer model...")
+                    self.orchestrator.load_transformer_model()
+                    self.loaded_models['Transformer'] = self.orchestrator.primary_transformer_trainer
+                    logger.info("Transformer model loaded successfully")
+
+                elif model_name == 'CNN':
+                    logger.info("Loading CNN model...")
+                    self.orchestrator.load_cnn_model()
+                    self.loaded_models['CNN'] = self.orchestrator.cnn_model
+                    logger.info("CNN model loaded successfully")
+
+                elif model_name == 'DQN':
+                    logger.info("Loading DQN model...")
+                    self.orchestrator.load_dqn_model()
+                    self.loaded_models['DQN'] = self.orchestrator.dqn_agent
+                    logger.info("DQN model loaded successfully")
+
+                else:
+                    logger.warning(f"Unknown model name: {model_name}")
+                    return
+
+                # Get checkpoint info for display
+                checkpoint_info = self._get_best_checkpoint_info(model_name)
+                if checkpoint_info:
+                    logger.info(f" Checkpoint: {checkpoint_info.get('filename', 'N/A')}")
+                    if checkpoint_info.get('accuracy'):
+                        logger.info(f" Accuracy: {checkpoint_info['accuracy']:.2%}")
+                    if checkpoint_info.get('loss'):
+                        logger.info(f" Loss: {checkpoint_info['loss']:.4f}")
+
+                self.models_loading = False
+                logger.info(f"{model_name} model ready for inference and training")
+
+            except Exception as e:
+                logger.error(f"Error auto-loading {model_name} model: {e}")
+                import traceback
+                logger.error(traceback.format_exc())
+                self.models_loading = False
+
+        # Start loading in background thread
+        self.models_loading = True
+        thread = threading.Thread(target=load_in_background, daemon=True)
+        thread.start()
+
     def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
         """
         Get best checkpoint info for a model without loading it
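Note that _auto_load_model returns immediately: the work runs in a daemon thread, self.models_loading flips back to False when loading finishes or fails, and successful loads are recorded in self.loaded_models. Callers should therefore treat the flag as a readiness signal rather than assume the model is available right after startup. A sketch of that polling pattern (the wait_until_ready helper and its parameters are illustrative, not part of this commit):

    import time

    def wait_until_ready(dashboard, timeout: float = 60.0, poll: float = 0.5) -> bool:
        # Assumes `dashboard` exposes the `models_loading` flag and the
        # `loaded_models` dict shown in the diff; this helper itself is hypothetical.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if not dashboard.models_loading and dashboard.loaded_models:
                return True   # background thread finished and stored a model
            time.sleep(poll)
        return False          # timed out, or loading failed and left no model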