improved model loading and training
This commit is contained in:
@@ -158,18 +158,19 @@ class AnnotationDashboard:
|
||||
if self.data_provider:
|
||||
self._enable_unified_storage_async()
|
||||
|
||||
# ANNOTATE doesn't need orchestrator immediately - load async for fast startup
|
||||
# ANNOTATE doesn't need orchestrator immediately - lazy load on demand
|
||||
self.orchestrator = None
|
||||
self.models_loading = True
|
||||
self.available_models = []
|
||||
self.models_loading = False
|
||||
self.available_models = ['DQN', 'CNN', 'Transformer'] # Models that CAN be loaded
|
||||
self.loaded_models = {} # Models that ARE loaded: {name: model_instance}
|
||||
|
||||
# Initialize ANNOTATE components
|
||||
self.annotation_manager = AnnotationManager()
|
||||
# Use REAL training adapter - NO SIMULATION!
|
||||
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
|
||||
|
||||
# Start async model loading in background
|
||||
self._start_async_model_loading()
|
||||
# Don't auto-load models - wait for user to click LOAD button
|
||||
logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
|
||||
|
||||
# Initialize data loader with existing DataProvider
|
||||
self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
|
||||
@@ -184,89 +185,93 @@ class AnnotationDashboard:
|
||||
|
||||
logger.info("Annotation Dashboard initialized")
|
||||
|
||||
def _start_async_model_loading(self):
    """Load ML models asynchronously in background thread with retry logic"""
    import threading
    import time

    def load_models():
        max_retries = 3
        retry_delay = 5  # seconds

        for attempt in range(max_retries):
            try:
                if attempt > 0:
                    logger.info(f" Retry attempt {attempt + 1}/{max_retries} for model loading...")
                    time.sleep(retry_delay)
                else:
                    logger.info(" Starting async model loading...")

                # Check if TradingOrchestrator is available
                if not TradingOrchestrator:
                    logger.error(" TradingOrchestrator class not available")
                    self.models_loading = False
                    self.available_models = []
                    return

                # Initialize orchestrator with models
                logger.info(" Creating TradingOrchestrator instance...")
                self.orchestrator = TradingOrchestrator(
                    data_provider=self.data_provider,
                    enhanced_rl_training=True
                )
                logger.info(" Orchestrator created")

                # Initialize ML models
                logger.info(" Initializing ML models...")
                self.orchestrator._initialize_ml_models()
                logger.info(" ML models initialized")

                # Update training adapter with orchestrator
                self.training_adapter.orchestrator = self.orchestrator
                logger.info(" Training adapter updated")

                # Get available models from orchestrator
                available = []
                if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                    available.append('DQN')
                    logger.info(" DQN model available")
                if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                    available.append('CNN')
                    logger.info(" CNN model available")
                if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
                    available.append('Transformer')
                    logger.info(" Transformer model available")

                self.available_models = available

                if available:
                    logger.info(f" Models loaded successfully: {', '.join(available)}")
                else:
                    logger.warning(" No models were initialized (this might be normal if models aren't configured)")

                self.models_loading = False
                logger.info(" Async model loading complete")
                return  # Success - exit retry loop

            except Exception as e:
                logger.error(f" Error loading models (attempt {attempt + 1}/{max_retries}): {e}")
                import traceback
                logger.error(f"Traceback:\n{traceback.format_exc()}")

                if attempt == max_retries - 1:
                    # Final attempt failed
                    logger.error(f" Model loading failed after {max_retries} attempts")
                    self.models_loading = False
                    self.available_models = []
                else:
                    logger.info(f" Will retry in {retry_delay} seconds...")

    # Start loading in background thread
    thread = threading.Thread(target=load_models, daemon=True, name="ModelLoader")
    thread.start()
    logger.info(f" Model loading started in background thread (ID: {thread.ident}, Name: {thread.name})")
    logger.info(" UI remains responsive while models load...")
    logger.info(" Will retry up to 3 times if loading fails")


def _load_model_lazy(self, model_name: str) -> dict:
    """
    Lazy load a specific model on demand.

    Creates the TradingOrchestrator on first use, attaches it to the
    training adapter, asks the orchestrator to initialize the requested
    model if it does not already hold one, and records the model in
    ``self.loaded_models``.

    Args:
        model_name: Name of model to load ('DQN', 'CNN', 'Transformer')

    Returns:
        dict: Result with success status and message. On success it holds
        ``success=True`` and ``message`` plus either ``already_loaded`` or
        ``loaded_models`` (names loaded so far). On failure it holds
        ``success=False`` and ``error``. Exceptions are logged and reported
        through ``error`` rather than raised.
    """
    # model name -> (orchestrator attribute holding the model,
    #                orchestrator method that initializes it)
    model_specs = {
        'DQN': ('rl_agent', '_initialize_rl_agent'),
        'CNN': ('cnn_model', '_initialize_cnn_model'),
        'Transformer': ('primary_transformer', '_initialize_transformer_model'),
    }

    try:
        # Check if already loaded
        if model_name in self.loaded_models:
            return {
                'success': True,
                'message': f'{model_name} already loaded',
                'already_loaded': True
            }

        # Check if model is available
        if model_name not in self.available_models:
            return {
                'success': False,
                'error': f'{model_name} is not in available models list'
            }

        # Reject names we have no loader for before paying for an orchestrator
        if model_name not in model_specs:
            return {
                'success': False,
                'error': f'Unknown model: {model_name}'
            }

        logger.info(f"Loading {model_name} model...")

        # Initialize orchestrator if not already done
        if not self.orchestrator:
            if not TradingOrchestrator:
                return {
                    'success': False,
                    'error': 'TradingOrchestrator class not available'
                }

            logger.info("Creating TradingOrchestrator instance...")
            self.orchestrator = TradingOrchestrator(
                data_provider=self.data_provider,
                enhanced_rl_training=True
            )
            logger.info("Orchestrator created")

            # Update training adapter
            self.training_adapter.orchestrator = self.orchestrator

        # Load specific model
        attr_name, init_name = model_specs[model_name]
        if not getattr(self.orchestrator, attr_name, None):
            getattr(self.orchestrator, init_name)()

        model = getattr(self.orchestrator, attr_name, None)
        if model is None:
            # Never register a missing model: a None entry would make every
            # later call report "already loaded" for a model that is not there.
            logger.error(f"{model_name} model failed to initialize")
            return {
                'success': False,
                'error': f'{model_name} model failed to initialize'
            }

        self.loaded_models[model_name] = model
        logger.info(f"{model_name} model loaded successfully")

        return {
            'success': True,
            'message': f'{model_name} loaded successfully',
            'loaded_models': list(self.loaded_models.keys())
        }

    except Exception as e:
        logger.error(f"Error loading {model_name}: {e}")
        import traceback
        logger.error(f"Traceback:\n{traceback.format_exc()}")
        return {
            'success': False,
            'error': str(e)
        }
|
||||
|
||||
def _enable_unified_storage_async(self):
|
||||
"""Enable unified storage system in background thread"""
|
||||
@@ -1122,47 +1127,66 @@ class AnnotationDashboard:
|
||||
|
||||
@self.server.route('/api/available-models', methods=['GET'])
def get_available_models():
    """Get list of available models with their load status.

    Returns a JSON response. On success: ``success``, ``models`` (one
    entry per name in ``self.available_models`` with ``name``, ``loaded``,
    ``can_train`` and ``can_infer``), ``loaded_count`` and
    ``available_count``. On failure: ``success=False`` and an ``error``
    object with ``code`` and ``message``.
    """
    try:
        # Use self.available_models which is a simple list of strings
        # Don't call training_adapter.get_available_models() as it may return objects

        # Build model state list; training and inference both require the
        # model to be loaded, so all three flags share one value.
        model_states = [
            {
                'name': model_name,
                'loaded': model_name in self.loaded_models,
                'can_train': model_name in self.loaded_models,
                'can_infer': model_name in self.loaded_models
            }
            for model_name in self.available_models
        ]

        return jsonify({
            'success': True,
            'models': model_states,
            'loaded_count': len(self.loaded_models),
            'available_count': len(self.available_models)
        })

    except Exception as e:
        logger.error(f"Error getting available models: {e}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return jsonify({
            'success': False,
            'error': {
                'code': 'MODEL_LIST_ERROR',
                'message': str(e)
            }
        })
|
||||
|
||||
@self.server.route('/api/load-model', methods=['POST'])
def load_model():
    """Load a specific model on demand.

    Expects a JSON object body with a ``model_name`` string and returns
    the result dict of ``self._load_model_lazy`` as JSON. Any problem with
    the request is reported as ``{'success': False, 'error': ...}``.
    """
    try:
        # silent=True: an empty, malformed or non-JSON body yields None
        # instead of raising, so the caller gets the "model_name is required"
        # message rather than a parser or AttributeError message.
        payload = request.get_json(silent=True)
        model_name = payload.get('model_name') if isinstance(payload, dict) else None

        if not model_name or not isinstance(model_name, str):
            return jsonify({
                'success': False,
                'error': 'model_name is required'
            })

        # Load the model
        result = self._load_model_lazy(model_name)

        return jsonify(result)

    except Exception as e:
        logger.error(f"Error in load_model endpoint: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        })
|
||||
|
||||
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
||||
def start_realtime_inference():
|
||||
"""Start real-time inference mode"""
|
||||
|
||||
@@ -16,10 +16,14 @@
|
||||
|
||||
<!-- Training Controls -->
|
||||
<div class="mb-3">
|
||||
<button class="btn btn-primary btn-sm w-100" id="train-model-btn">
|
||||
<button class="btn btn-primary btn-sm w-100" id="train-model-btn" style="display: none;">
|
||||
<i class="fas fa-play"></i>
|
||||
Train Model
|
||||
</button>
|
||||
<button class="btn btn-success btn-sm w-100" id="load-model-btn" style="display: none;">
|
||||
<i class="fas fa-download"></i>
|
||||
Load Model
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Training Status -->
|
||||
@@ -102,71 +106,95 @@
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Load available models on page load with polling for async loading
|
||||
let modelLoadingPollInterval = null;
|
||||
|
||||
// Track model states
|
||||
let modelStates = [];
|
||||
let selectedModel = null;
|
||||
|
||||
/**
 * Fetch the model list from /api/available-models and rebuild the
 * #model-select dropdown, marking each model as loaded or not loaded.
 * Stores the returned entries in `modelStates` and refreshes the
 * train/load/inference buttons via updateButtonState().
 */
function loadAvailableModels() {
    fetch('/api/available-models')
        .then(response => response.json())
        .then(data => {
            console.log('📊 Available models API response:', JSON.stringify(data, null, 2));
            const modelSelect = document.getElementById('model-select');

            if (data.success && data.models && Array.isArray(data.models)) {
                // Remember the current choice: rebuilding the options below
                // would otherwise reset the dropdown to the placeholder, hiding
                // the Train button right after a model finished loading.
                const previousSelection = modelSelect.value;

                modelStates = data.models;
                modelSelect.innerHTML = '';

                // Add placeholder option
                const placeholder = document.createElement('option');
                placeholder.value = '';
                placeholder.textContent = 'Select a model...';
                modelSelect.appendChild(placeholder);

                // Add model options with load status
                data.models.forEach((model, index) => {
                    console.log(`   Model ${index}:`, model, 'Type:', typeof model);

                    // Ensure model is an object with name property
                    const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model);
                    const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false;

                    console.log(`   → Name: "${modelName}", Loaded: ${isLoaded}`);

                    const option = document.createElement('option');
                    option.value = modelName;
                    option.textContent = modelName + (isLoaded ? ' ✓' : ' (not loaded)');
                    option.dataset.loaded = isLoaded;
                    if (previousSelection && modelName === previousSelection) {
                        option.selected = true;
                    }
                    modelSelect.appendChild(option);
                });

                console.log(`✓ Models available: ${data.available_count}, loaded: ${data.loaded_count}`);

                // Update button state for currently selected model
                updateButtonState();
            } else {
                console.error('❌ Invalid response format:', data);
                modelSelect.innerHTML = '<option value="">No models available</option>';
            }
        })
        .catch(error => {
            console.error('❌ Error loading models:', error);
            const modelSelect = document.getElementById('model-select');
            modelSelect.innerHTML = '<option value="">Error loading models</option>';
        });
}
|
||||
|
||||
/**
 * Sync the train / load / inference buttons with the model currently
 * chosen in #model-select:
 *   - nothing selected   -> train and load hidden, inference disabled
 *   - selected, loaded   -> train shown, load hidden, inference enabled
 *   - selected, unloaded -> train hidden, load shown, inference disabled
 * Also records the choice in the `selectedModel` variable.
 */
function updateButtonState() {
    const byId = (id) => document.getElementById(id);

    const dropdown = byId('model-select');
    const trainBtn = byId('train-model-btn');
    const loadBtn = byId('load-model-btn');
    const inferenceBtn = byId('start-inference-btn');

    selectedModel = dropdown.value;

    // Look up the state only when something is selected
    const state = selectedModel
        ? modelStates.find(entry => entry.name === selectedModel)
        : undefined;
    const isLoaded = Boolean(state && state.loaded);
    const needsLoad = Boolean(selectedModel) && !isLoaded;

    trainBtn.style.display = isLoaded ? 'block' : 'none';
    loadBtn.style.display = needsLoad ? 'block' : 'none';
    inferenceBtn.disabled = !isLoaded;
}
|
||||
|
||||
// Update button state when model selection changes
|
||||
document.getElementById('model-select').addEventListener('change', updateButtonState);
|
||||
|
||||
// Load models when page loads
|
||||
if (document.readyState === 'loading') {
|
||||
document.addEventListener('DOMContentLoaded', loadAvailableModels);
|
||||
@@ -174,6 +202,45 @@
|
||||
loadAvailableModels();
|
||||
}
|
||||
|
||||
// Load model button handler
|
||||
// Load model button handler: POSTs the selected model name to
// /api/load-model, shows a spinner while the request is in flight, and
// refreshes the model list on success.
document.getElementById('load-model-btn').addEventListener('click', function () {
    const modelName = document.getElementById('model-select').value;

    if (!modelName) {
        showError('Please select a model first');
        return;
    }

    // Disable button and show loading
    const loadBtn = this;
    loadBtn.disabled = true;
    loadBtn.innerHTML = '<span class="spinner-border spinner-border-sm me-1"></span>Loading...';

    // Restore the idle look. This runs on success too: the button is only
    // hidden after a successful load, so without the reset it would come back
    // disabled and still showing "Loading..." for the next unloaded model.
    const resetButton = () => {
        loadBtn.disabled = false;
        loadBtn.innerHTML = '<i class="fas fa-download"></i> Load Model';
    };

    // Load the model
    fetch('/api/load-model', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model_name: modelName })
    })
        .then(response => response.json())
        .then(data => {
            resetButton();
            if (data.success) {
                showSuccess(`${modelName} loaded successfully`);
                // Refresh model list to update states
                loadAvailableModels();
            } else {
                showError(`Failed to load ${modelName}: ${data.error}`);
            }
        })
        .catch(error => {
            resetButton();
            showError('Network error: ' + error.message);
        });
});
|
||||
|
||||
// Train model button
|
||||
document.getElementById('train-model-btn').addEventListener('click', function () {
|
||||
const modelName = document.getElementById('model-select').value;
|
||||
|
||||
Reference in New Issue
Block a user