model auto load

This commit is contained in:
Dobromir Popov
2025-11-22 12:49:57 +02:00
parent e404658dc7
commit bccac9614d
3 changed files with 334 additions and 2 deletions

View File

@@ -589,8 +589,14 @@ class AnnotationDashboard:
# Backtest runner for replaying visible chart with predictions
self.backtest_runner = BacktestRunner()
# Don't auto-load models - wait for user to click LOAD button
logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
# Check if we should auto-load a model at startup
auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer
if auto_load_model and auto_load_model.lower() != 'none':
logger.info(f"Auto-loading model: {auto_load_model}")
self._auto_load_model(auto_load_model)
else:
logger.info("Auto-load disabled. Models available for lazy loading: " + ", ".join(self.available_models))
# Initialize data loader with existing DataProvider
self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
@@ -605,6 +611,73 @@ class AnnotationDashboard:
logger.info("Annotation Dashboard initialized")
def _auto_load_model(self, model_name: str):
"""
Auto-load a model at startup in background thread
Args:
model_name: Name of model to load (DQN, CNN, or Transformer)
"""
def load_in_background():
try:
logger.info(f"Starting auto-load for {model_name}...")
# Initialize orchestrator if not already done
if not self.orchestrator:
logger.info("Initializing TradingOrchestrator...")
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
config=self.config
)
self.training_adapter.orchestrator = self.orchestrator
logger.info("TradingOrchestrator initialized")
# Load the specific model
if model_name == 'Transformer':
logger.info("Loading Transformer model...")
self.orchestrator.load_transformer_model()
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer_trainer
logger.info("Transformer model loaded successfully")
elif model_name == 'CNN':
logger.info("Loading CNN model...")
self.orchestrator.load_cnn_model()
self.loaded_models['CNN'] = self.orchestrator.cnn_model
logger.info("CNN model loaded successfully")
elif model_name == 'DQN':
logger.info("Loading DQN model...")
self.orchestrator.load_dqn_model()
self.loaded_models['DQN'] = self.orchestrator.dqn_agent
logger.info("DQN model loaded successfully")
else:
logger.warning(f"Unknown model name: {model_name}")
return
# Get checkpoint info for display
checkpoint_info = self._get_best_checkpoint_info(model_name)
if checkpoint_info:
logger.info(f" Checkpoint: {checkpoint_info.get('filename', 'N/A')}")
if checkpoint_info.get('accuracy'):
logger.info(f" Accuracy: {checkpoint_info['accuracy']:.2%}")
if checkpoint_info.get('loss'):
logger.info(f" Loss: {checkpoint_info['loss']:.4f}")
self.models_loading = False
logger.info(f"{model_name} model ready for inference and training")
except Exception as e:
logger.error(f"Error auto-loading {model_name} model: {e}")
import traceback
logger.error(traceback.format_exc())
self.models_loading = False
# Start loading in background thread
self.models_loading = True
thread = threading.Thread(target=load_in_background, daemon=True)
thread.start()
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
"""
Get best checkpoint info for a model without loading it