improved model loading and training

Dobromir Popov
2025-10-31 01:22:49 +02:00
parent 7ddf98bf18
commit ba91740e4c
7 changed files with 745 additions and 186 deletions

@@ -158,18 +158,19 @@ class AnnotationDashboard:
         if self.data_provider:
             self._enable_unified_storage_async()

-        # ANNOTATE doesn't need orchestrator immediately - load async for fast startup
+        # ANNOTATE doesn't need orchestrator immediately - lazy load on demand
         self.orchestrator = None
-        self.models_loading = True
-        self.available_models = []
+        self.models_loading = False
+        self.available_models = ['DQN', 'CNN', 'Transformer']  # Models that CAN be loaded
+        self.loaded_models = {}  # Models that ARE loaded: {name: model_instance}

         # Initialize ANNOTATE components
         self.annotation_manager = AnnotationManager()

         # Use REAL training adapter - NO SIMULATION!
         self.training_adapter = RealTrainingAdapter(None, self.data_provider)

-        # Start async model loading in background
-        self._start_async_model_loading()
+        # Don't auto-load models - wait for user to click LOAD button
+        logger.info("Models available for lazy loading: " + ", ".join(self.available_models))

         # Initialize data loader with existing DataProvider
         self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
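The net effect of the hunk above is that the dashboard now starts with a fixed list of loadable model names and an empty registry of loaded instances, and only constructs a model when the user asks for it. A minimal sketch of that flow, assuming a constructed AnnotationDashboard instance named `dashboard` (the variable name and printed values are illustrative, not taken from this commit):

# Sketch only - not part of the commit.
print(dashboard.available_models)          # ['DQN', 'CNN', 'Transformer'] - loadable, not yet loaded
print(dashboard.loaded_models)             # {} - nothing is constructed at startup

result = dashboard._load_model_lazy('DQN') # same call the LOAD button triggers
if result.get('success'):
    print(result.get('loaded_models'))     # e.g. ['DQN']
else:
    print(result.get('error'))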
@@ -184,89 +185,93 @@ class AnnotationDashboard:
logger.info("Annotation Dashboard initialized")
def _start_async_model_loading(self):
"""Load ML models asynchronously in background thread with retry logic"""
import threading
import time
def _load_model_lazy(self, model_name: str) -> dict:
"""
Lazy load a specific model on demand
def load_models():
max_retries = 3
retry_delay = 5 # seconds
Args:
model_name: Name of model to load ('DQN', 'CNN', 'Transformer')
for attempt in range(max_retries):
try:
if attempt > 0:
logger.info(f" Retry attempt {attempt + 1}/{max_retries} for model loading...")
time.sleep(retry_delay)
else:
logger.info(" Starting async model loading...")
# Check if TradingOrchestrator is available
if not TradingOrchestrator:
logger.error(" TradingOrchestrator class not available")
self.models_loading = False
self.available_models = []
return
# Initialize orchestrator with models
logger.info(" Creating TradingOrchestrator instance...")
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
enhanced_rl_training=True
)
logger.info(" Orchestrator created")
# Initialize ML models
logger.info(" Initializing ML models...")
self.orchestrator._initialize_ml_models()
logger.info(" ML models initialized")
# Update training adapter with orchestrator
self.training_adapter.orchestrator = self.orchestrator
logger.info(" Training adapter updated")
# Get available models from orchestrator
available = []
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
available.append('DQN')
logger.info(" DQN model available")
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
available.append('CNN')
logger.info(" CNN model available")
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
available.append('Transformer')
logger.info(" Transformer model available")
self.available_models = available
if available:
logger.info(f" Models loaded successfully: {', '.join(available)}")
else:
logger.warning(" No models were initialized (this might be normal if models aren't configured)")
self.models_loading = False
logger.info(" Async model loading complete")
return # Success - exit retry loop
except Exception as e:
logger.error(f" Error loading models (attempt {attempt + 1}/{max_retries}): {e}")
import traceback
logger.error(f"Traceback:\n{traceback.format_exc()}")
if attempt == max_retries - 1:
# Final attempt failed
logger.error(f" Model loading failed after {max_retries} attempts")
self.models_loading = False
self.available_models = []
else:
logger.info(f" Will retry in {retry_delay} seconds...")
# Start loading in background thread
thread = threading.Thread(target=load_models, daemon=True, name="ModelLoader")
thread.start()
logger.info(f" Model loading started in background thread (ID: {thread.ident}, Name: {thread.name})")
logger.info(" UI remains responsive while models load...")
logger.info(" Will retry up to 3 times if loading fails")
Returns:
dict: Result with success status and message
"""
try:
# Check if already loaded
if model_name in self.loaded_models:
return {
'success': True,
'message': f'{model_name} already loaded',
'already_loaded': True
}
# Check if model is available
if model_name not in self.available_models:
return {
'success': False,
'error': f'{model_name} is not in available models list'
}
logger.info(f"Loading {model_name} model...")
# Initialize orchestrator if not already done
if not self.orchestrator:
if not TradingOrchestrator:
return {
'success': False,
'error': 'TradingOrchestrator class not available'
}
logger.info("Creating TradingOrchestrator instance...")
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
enhanced_rl_training=True
)
logger.info("Orchestrator created")
# Update training adapter
self.training_adapter.orchestrator = self.orchestrator
# Load specific model
if model_name == 'DQN':
if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
# Initialize RL agent
self.orchestrator._initialize_rl_agent()
self.loaded_models['DQN'] = self.orchestrator.rl_agent
elif model_name == 'CNN':
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
# Initialize CNN model
self.orchestrator._initialize_cnn_model()
self.loaded_models['CNN'] = self.orchestrator.cnn_model
elif model_name == 'Transformer':
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
# Initialize Transformer model
self.orchestrator._initialize_transformer_model()
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
else:
return {
'success': False,
'error': f'Unknown model: {model_name}'
}
logger.info(f"{model_name} model loaded successfully")
return {
'success': True,
'message': f'{model_name} loaded successfully',
'loaded_models': list(self.loaded_models.keys())
}
except Exception as e:
logger.error(f"Error loading {model_name}: {e}")
import traceback
logger.error(f"Traceback:\n{traceback.format_exc()}")
return {
'success': False,
'error': str(e)
}
def _enable_unified_storage_async(self):
"""Enable unified storage system in background thread"""
@@ -1122,47 +1127,66 @@ class AnnotationDashboard:
         @self.server.route('/api/available-models', methods=['GET'])
         def get_available_models():
-            """Get list of available models with loading status"""
+            """Get list of available models with their load status"""
             try:
                 if not self.training_adapter:
                     return jsonify({
                         'success': False,
                         'loading': False,
                         'error': {
                             'code': 'TRAINING_UNAVAILABLE',
                             'message': 'Real training adapter not available'
                         }
                     })

-                # Check if models are still loading
-                if self.models_loading:
-                    return jsonify({
-                        'success': True,
-                        'loading': True,
-                        'models': [],
-                        'message': 'Models are loading in background...'
-                    })
-                # Models loaded - get the list
-                models = self.training_adapter.get_available_models()
+                # Use self.available_models which is a simple list of strings
+                # Don't call training_adapter.get_available_models() as it may return objects
+                # Build model state dict
+                model_states = []
+                for model_name in self.available_models:
+                    is_loaded = model_name in self.loaded_models
+                    model_states.append({
+                        'name': model_name,
+                        'loaded': is_loaded,
+                        'can_train': is_loaded,
+                        'can_infer': is_loaded
+                    })

                 return jsonify({
                     'success': True,
-                    'loading': False,
-                    'models': models
+                    'models': model_states,
+                    'loaded_count': len(self.loaded_models),
+                    'available_count': len(self.available_models)
                 })
             except Exception as e:
                 logger.error(f"Error getting available models: {e}")
                 import traceback
                 logger.error(f"Traceback: {traceback.format_exc()}")
                 return jsonify({
                     'success': False,
                     'loading': False,
                     'error': {
                         'code': 'MODEL_LIST_ERROR',
                         'message': str(e)
                     }
                 })

+        @self.server.route('/api/load-model', methods=['POST'])
+        def load_model():
+            """Load a specific model on demand"""
+            try:
+                data = request.get_json()
+                model_name = data.get('model_name')
+                if not model_name:
+                    return jsonify({
+                        'success': False,
+                        'error': 'model_name is required'
+                    })
+
+                # Load the model
+                result = self._load_model_lazy(model_name)
+                return jsonify(result)
+
+            except Exception as e:
+                logger.error(f"Error in load_model endpoint: {e}")
+                return jsonify({
+                    'success': False,
+                    'error': str(e)
+                })
+
         @self.server.route('/api/realtime-inference/start', methods=['POST'])
         def start_realtime_inference():
             """Start real-time inference mode"""

@@ -16,10 +16,14 @@
 <!-- Training Controls -->
 <div class="mb-3">
-    <button class="btn btn-primary btn-sm w-100" id="train-model-btn">
+    <button class="btn btn-primary btn-sm w-100" id="train-model-btn" style="display: none;">
         <i class="fas fa-play"></i>
         Train Model
     </button>
+    <button class="btn btn-success btn-sm w-100" id="load-model-btn" style="display: none;">
+        <i class="fas fa-download"></i>
+        Load Model
+    </button>
 </div>

 <!-- Training Status -->
@@ -102,71 +106,95 @@
 </div>
 <script>
-    // Load available models on page load with polling for async loading
-    let modelLoadingPollInterval = null;
+    // Track model states
+    let modelStates = [];
+    let selectedModel = null;

     function loadAvailableModels() {
         fetch('/api/available-models')
             .then(response => response.json())
             .then(data => {
+                console.log('📊 Available models API response:', JSON.stringify(data, null, 2));
                 const modelSelect = document.getElementById('model-select');

-                if (data.loading) {
-                    // Models still loading - show loading message and poll
-                    modelSelect.innerHTML = '<option value=""> Loading models...</option>';
-                    // Start polling if not already polling
-                    if (!modelLoadingPollInterval) {
-                        console.log('Models loading in background, will poll for completion...');
-                        modelLoadingPollInterval = setInterval(loadAvailableModels, 2000); // Poll every 2 seconds
-                    }
-                } else {
-                    // Models loaded - stop polling
-                    if (modelLoadingPollInterval) {
-                        clearInterval(modelLoadingPollInterval);
-                        modelLoadingPollInterval = null;
-                    }
-                    modelSelect.innerHTML = '';
-                    if (data.success && data.models.length > 0) {
-                        // Show success notification
-                        if (window.showSuccess) {
-                            window.showSuccess(` ${data.models.length} models loaded and ready for training`);
-                        }
-                        data.models.forEach(model => {
-                            const option = document.createElement('option');
-                            option.value = model;
-                            option.textContent = model;
-                            modelSelect.appendChild(option);
-                        });
-                        console.log(` Models loaded: ${data.models.join(', ')}`);
-                    } else {
-                        const option = document.createElement('option');
-                        option.value = '';
-                        option.textContent = 'No models available';
-                        modelSelect.appendChild(option);
-                    }
-                }
+                if (data.success && data.models && Array.isArray(data.models)) {
+                    modelStates = data.models;
+                    modelSelect.innerHTML = '';
+                    // Add placeholder option
+                    const placeholder = document.createElement('option');
+                    placeholder.value = '';
+                    placeholder.textContent = 'Select a model...';
+                    modelSelect.appendChild(placeholder);
+                    // Add model options with load status
+                    data.models.forEach((model, index) => {
+                        console.log(` Model ${index}:`, model, 'Type:', typeof model);
+                        // Ensure model is an object with name property
+                        const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model);
+                        const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false;
+                        console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`);
+                        const option = document.createElement('option');
+                        option.value = modelName;
+                        option.textContent = modelName + (isLoaded ? ' ✓' : ' (not loaded)');
+                        option.dataset.loaded = isLoaded;
+                        modelSelect.appendChild(option);
+                    });
+                    console.log(`✓ Models available: ${data.available_count}, loaded: ${data.loaded_count}`);
+                    // Update button state for currently selected model
+                    updateButtonState();
+                } else {
+                    console.error('❌ Invalid response format:', data);
+                    modelSelect.innerHTML = '<option value="">No models available</option>';
+                }
             })
             .catch(error => {
-                console.error('Error loading models:', error);
+                console.error('Error loading models:', error);
                 const modelSelect = document.getElementById('model-select');
-                // Don't stop polling on network errors - keep trying
-                if (!modelLoadingPollInterval) {
-                    modelSelect.innerHTML = '<option value=""> Connection error, retrying...</option>';
-                    // Start polling to retry
-                    modelLoadingPollInterval = setInterval(loadAvailableModels, 3000); // Poll every 3 seconds
-                } else {
-                    // Already polling, just update the message
-                    modelSelect.innerHTML = '<option value=""> Retrying...</option>';
-                }
+                modelSelect.innerHTML = '<option value="">Error loading models</option>';
             });
     }

+    function updateButtonState() {
+        const modelSelect = document.getElementById('model-select');
+        const trainBtn = document.getElementById('train-model-btn');
+        const loadBtn = document.getElementById('load-model-btn');
+        const inferenceBtn = document.getElementById('start-inference-btn');
+
+        selectedModel = modelSelect.value;
+
+        if (!selectedModel) {
+            // No model selected
+            trainBtn.style.display = 'none';
+            loadBtn.style.display = 'none';
+            inferenceBtn.disabled = true;
+            return;
+        }
+
+        // Find model state
+        const modelState = modelStates.find(m => m.name === selectedModel);
+
+        if (modelState && modelState.loaded) {
+            // Model is loaded - show train/inference buttons
+            trainBtn.style.display = 'block';
+            loadBtn.style.display = 'none';
+            inferenceBtn.disabled = false;
+        } else {
+            // Model not loaded - show load button
+            trainBtn.style.display = 'none';
+            loadBtn.style.display = 'block';
+            inferenceBtn.disabled = true;
+        }
+    }
+
+    // Update button state when model selection changes
+    document.getElementById('model-select').addEventListener('change', updateButtonState);
+
     // Load models when page loads
     if (document.readyState === 'loading') {
         document.addEventListener('DOMContentLoaded', loadAvailableModels);
@@ -174,6 +202,45 @@
         loadAvailableModels();
     }

+    // Load model button handler
+    document.getElementById('load-model-btn').addEventListener('click', function () {
+        const modelName = document.getElementById('model-select').value;
+
+        if (!modelName) {
+            showError('Please select a model first');
+            return;
+        }
+
+        // Disable button and show loading
+        const loadBtn = this;
+        loadBtn.disabled = true;
+        loadBtn.innerHTML = '<span class="spinner-border spinner-border-sm me-1"></span>Loading...';
+
+        // Load the model
+        fetch('/api/load-model', {
+            method: 'POST',
+            headers: { 'Content-Type': 'application/json' },
+            body: JSON.stringify({ model_name: modelName })
+        })
+            .then(response => response.json())
+            .then(data => {
+                if (data.success) {
+                    showSuccess(`${modelName} loaded successfully`);
+                    // Refresh model list to update states
+                    loadAvailableModels();
+                } else {
+                    showError(`Failed to load ${modelName}: ${data.error}`);
+                    loadBtn.disabled = false;
+                    loadBtn.innerHTML = '<i class="fas fa-download"></i> Load Model';
+                }
+            })
+            .catch(error => {
+                showError('Network error: ' + error.message);
+                loadBtn.disabled = false;
+                loadBtn.innerHTML = '<i class="fas fa-download"></i> Load Model';
+            });
+    });
+
     // Train model button
     document.getElementById('train-model-btn').addEventListener('click', function () {
         const modelName = document.getElementById('model-select').value;