improved model loading and training
This commit is contained in:
@@ -158,18 +158,19 @@ class AnnotationDashboard:
|
||||
if self.data_provider:
|
||||
self._enable_unified_storage_async()
|
||||
|
||||
# ANNOTATE doesn't need orchestrator immediately - load async for fast startup
|
||||
# ANNOTATE doesn't need orchestrator immediately - lazy load on demand
|
||||
self.orchestrator = None
|
||||
self.models_loading = True
|
||||
self.available_models = []
|
||||
self.models_loading = False
|
||||
self.available_models = ['DQN', 'CNN', 'Transformer'] # Models that CAN be loaded
|
||||
self.loaded_models = {} # Models that ARE loaded: {name: model_instance}
|
||||
|
||||
# Initialize ANNOTATE components
|
||||
self.annotation_manager = AnnotationManager()
|
||||
# Use REAL training adapter - NO SIMULATION!
|
||||
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
|
||||
|
||||
# Start async model loading in background
|
||||
self._start_async_model_loading()
|
||||
# Don't auto-load models - wait for user to click LOAD button
|
||||
logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
|
||||
|
||||
# Initialize data loader with existing DataProvider
|
||||
self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
|
||||
@@ -184,89 +185,93 @@ class AnnotationDashboard:
|
||||
|
||||
logger.info("Annotation Dashboard initialized")
|
||||
|
||||
def _start_async_model_loading(self):
|
||||
"""Load ML models asynchronously in background thread with retry logic"""
|
||||
import threading
|
||||
import time
|
||||
def _load_model_lazy(self, model_name: str) -> dict:
|
||||
"""
|
||||
Lazy load a specific model on demand
|
||||
|
||||
def load_models():
|
||||
max_retries = 3
|
||||
retry_delay = 5 # seconds
|
||||
Args:
|
||||
model_name: Name of model to load ('DQN', 'CNN', 'Transformer')
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
logger.info(f" Retry attempt {attempt + 1}/{max_retries} for model loading...")
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
logger.info(" Starting async model loading...")
|
||||
|
||||
# Check if TradingOrchestrator is available
|
||||
if not TradingOrchestrator:
|
||||
logger.error(" TradingOrchestrator class not available")
|
||||
self.models_loading = False
|
||||
self.available_models = []
|
||||
return
|
||||
|
||||
# Initialize orchestrator with models
|
||||
logger.info(" Creating TradingOrchestrator instance...")
|
||||
self.orchestrator = TradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
logger.info(" Orchestrator created")
|
||||
|
||||
# Initialize ML models
|
||||
logger.info(" Initializing ML models...")
|
||||
self.orchestrator._initialize_ml_models()
|
||||
logger.info(" ML models initialized")
|
||||
|
||||
# Update training adapter with orchestrator
|
||||
self.training_adapter.orchestrator = self.orchestrator
|
||||
logger.info(" Training adapter updated")
|
||||
|
||||
# Get available models from orchestrator
|
||||
available = []
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
available.append('DQN')
|
||||
logger.info(" DQN model available")
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
available.append('CNN')
|
||||
logger.info(" CNN model available")
|
||||
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
||||
available.append('Transformer')
|
||||
logger.info(" Transformer model available")
|
||||
|
||||
self.available_models = available
|
||||
|
||||
if available:
|
||||
logger.info(f" Models loaded successfully: {', '.join(available)}")
|
||||
else:
|
||||
logger.warning(" No models were initialized (this might be normal if models aren't configured)")
|
||||
|
||||
self.models_loading = False
|
||||
logger.info(" Async model loading complete")
|
||||
return # Success - exit retry loop
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" Error loading models (attempt {attempt + 1}/{max_retries}): {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback:\n{traceback.format_exc()}")
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
# Final attempt failed
|
||||
logger.error(f" Model loading failed after {max_retries} attempts")
|
||||
self.models_loading = False
|
||||
self.available_models = []
|
||||
else:
|
||||
logger.info(f" Will retry in {retry_delay} seconds...")
|
||||
|
||||
# Start loading in background thread
|
||||
thread = threading.Thread(target=load_models, daemon=True, name="ModelLoader")
|
||||
thread.start()
|
||||
logger.info(f" Model loading started in background thread (ID: {thread.ident}, Name: {thread.name})")
|
||||
logger.info(" UI remains responsive while models load...")
|
||||
logger.info(" Will retry up to 3 times if loading fails")
|
||||
Returns:
|
||||
dict: Result with success status and message
|
||||
"""
|
||||
try:
|
||||
# Check if already loaded
|
||||
if model_name in self.loaded_models:
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'{model_name} already loaded',
|
||||
'already_loaded': True
|
||||
}
|
||||
|
||||
# Check if model is available
|
||||
if model_name not in self.available_models:
|
||||
return {
|
||||
'success': False,
|
||||
'error': f'{model_name} is not in available models list'
|
||||
}
|
||||
|
||||
logger.info(f"Loading {model_name} model...")
|
||||
|
||||
# Initialize orchestrator if not already done
|
||||
if not self.orchestrator:
|
||||
if not TradingOrchestrator:
|
||||
return {
|
||||
'success': False,
|
||||
'error': 'TradingOrchestrator class not available'
|
||||
}
|
||||
|
||||
logger.info("Creating TradingOrchestrator instance...")
|
||||
self.orchestrator = TradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
logger.info("Orchestrator created")
|
||||
|
||||
# Update training adapter
|
||||
self.training_adapter.orchestrator = self.orchestrator
|
||||
|
||||
# Load specific model
|
||||
if model_name == 'DQN':
|
||||
if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||
# Initialize RL agent
|
||||
self.orchestrator._initialize_rl_agent()
|
||||
self.loaded_models['DQN'] = self.orchestrator.rl_agent
|
||||
|
||||
elif model_name == 'CNN':
|
||||
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
|
||||
# Initialize CNN model
|
||||
self.orchestrator._initialize_cnn_model()
|
||||
self.loaded_models['CNN'] = self.orchestrator.cnn_model
|
||||
|
||||
elif model_name == 'Transformer':
|
||||
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
||||
# Initialize Transformer model
|
||||
self.orchestrator._initialize_transformer_model()
|
||||
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
|
||||
|
||||
else:
|
||||
return {
|
||||
'success': False,
|
||||
'error': f'Unknown model: {model_name}'
|
||||
}
|
||||
|
||||
logger.info(f"{model_name} model loaded successfully")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'{model_name} loaded successfully',
|
||||
'loaded_models': list(self.loaded_models.keys())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {model_name}: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback:\n{traceback.format_exc()}")
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def _enable_unified_storage_async(self):
|
||||
"""Enable unified storage system in background thread"""
|
||||
@@ -1122,47 +1127,66 @@ class AnnotationDashboard:
|
||||
|
||||
@self.server.route('/api/available-models', methods=['GET'])
|
||||
def get_available_models():
|
||||
"""Get list of available models with loading status"""
|
||||
"""Get list of available models with their load status"""
|
||||
try:
|
||||
if not self.training_adapter:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'loading': False,
|
||||
'error': {
|
||||
'code': 'TRAINING_UNAVAILABLE',
|
||||
'message': 'Real training adapter not available'
|
||||
}
|
||||
})
|
||||
# Use self.available_models which is a simple list of strings
|
||||
# Don't call training_adapter.get_available_models() as it may return objects
|
||||
|
||||
# Check if models are still loading
|
||||
if self.models_loading:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'loading': True,
|
||||
'models': [],
|
||||
'message': 'Models are loading in background...'
|
||||
# Build model state dict
|
||||
model_states = []
|
||||
for model_name in self.available_models:
|
||||
is_loaded = model_name in self.loaded_models
|
||||
model_states.append({
|
||||
'name': model_name,
|
||||
'loaded': is_loaded,
|
||||
'can_train': is_loaded,
|
||||
'can_infer': is_loaded
|
||||
})
|
||||
|
||||
# Models loaded - get the list
|
||||
models = self.training_adapter.get_available_models()
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'loading': False,
|
||||
'models': models
|
||||
'models': model_states,
|
||||
'loaded_count': len(self.loaded_models),
|
||||
'available_count': len(self.available_models)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting available models: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'loading': False,
|
||||
'error': {
|
||||
'code': 'MODEL_LIST_ERROR',
|
||||
'message': str(e)
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/load-model', methods=['POST'])
|
||||
def load_model():
|
||||
"""Load a specific model on demand"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
model_name = data.get('model_name')
|
||||
|
||||
if not model_name:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'model_name is required'
|
||||
})
|
||||
|
||||
# Load the model
|
||||
result = self._load_model_lazy(model_name)
|
||||
|
||||
return jsonify(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in load_model endpoint: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
||||
def start_realtime_inference():
|
||||
"""Start real-time inference mode"""
|
||||
|
||||
Reference in New Issue
Block a user