276 lines
10 KiB
Python
276 lines
10 KiB
Python
"""
CNN Dashboard Integration

This module integrates the EnhancedCNN model with the dashboard, providing real-time
training and visualization of model predictions.
"""
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
import os
|
|
import json
|
|
|
|
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
|
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
|
from utils.training_integration import get_training_integration
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class CNNDashboardIntegration:
    """
    Integrates the EnhancedCNN model with the dashboard

    This class:
    1. Loads and initializes the CNN model
    2. Processes real-time data for model inference
    3. Manages continuous training of the model (in a background daemon thread)
    4. Provides visualization data for the dashboard
    """

    def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the CNN dashboard integration

        Args:
            data_provider: Data provider instance. When set, every prediction produced
                by process_data() is passed to its store_model_output() method.
            checkpoint_dir: Directory to save checkpoints to (created if missing)
        """
        self.data_provider = data_provider
        self.checkpoint_dir = checkpoint_dir
        self.cnn_adapter = None
        self.training_thread = None
        self.training_active = False
        # Stop signal for the current training thread. A fresh Event is created on
        # every start so a stale thread that outlived a join() timeout cannot be
        # revived by a later restart.
        self._stop_event = threading.Event()
        self.training_interval = 60  # Train every 60 seconds
        # (symbol, action, reward) tuples used only for dashboard statistics; the
        # full samples are handed to the adapter in add_training_sample().
        self.training_samples = []
        self.max_training_samples = 1000
        self.last_training_time = 0
        self.last_predictions = {}
        self.performance_metrics = {}
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize CNN adapter
        self._initialize_cnn_adapter()

        logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")

    def _initialize_cnn_adapter(self):
        """
        Create the CNN adapter and load its best checkpoint.

        Any failure is logged and leaves ``self.cnn_adapter`` as None, which the
        other methods treat as "model unavailable".
        """
        try:
            # EnhancedCNNAdapter is imported at module level.
            self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)

            # Load best checkpoint if available
            self.cnn_adapter.load_best_checkpoint()

            logger.info("CNN adapter initialized successfully")

        except Exception as e:
            logger.exception(f"Error initializing CNN adapter: {e}")
            self.cnn_adapter = None

    def start_training_thread(self):
        """Start the background training thread; no-op if one is already running."""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.info("Training thread already running")
            return

        # Each thread gets its own stop event, passed in explicitly so there is no
        # window where the loop could pick up a different thread's event.
        self._stop_event = threading.Event()
        self.training_active = True
        self.training_thread = threading.Thread(
            target=self._training_loop, args=(self._stop_event,), daemon=True
        )
        self.training_thread.start()

        logger.info("CNN training thread started")

    def stop_training_thread(self):
        """Signal the training thread to stop and wait up to 5 seconds for it to exit."""
        self.training_active = False
        self._stop_event.set()
        if self.training_thread is not None:
            self.training_thread.join(timeout=5)
            if self.training_thread.is_alive():
                # Most likely still inside cnn_adapter.train(). Its stop event is
                # set, so it exits after the current step even if restarted meanwhile.
                logger.warning("CNN training thread did not stop within 5s; it will exit after the current step")
            self.training_thread = None

        logger.info("CNN training thread stopped")

    def _training_loop(self, stop_event: Optional[threading.Event] = None):
        """
        Train the model periodically until stopped.

        A training step runs when at least ``training_interval`` seconds have
        elapsed since the previous one and at least 10 samples have been recorded.

        Args:
            stop_event: Event that terminates this loop when set. Defaults to the
                instance's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        while self.training_active and not stop_event.is_set():
            try:
                # Check if it's time to train
                current_time = time.time()
                if (current_time - self.last_training_time >= self.training_interval
                        and len(self.training_samples) >= 10):
                    if self.cnn_adapter is not None:
                        logger.info(f"Training CNN model with {len(self.training_samples)} samples")

                        # Guard against train() returning None so metrics are still recorded.
                        metrics = self.cnn_adapter.train(epochs=1) or {}

                        # Update performance metrics (rebinding is a single assignment,
                        # so readers see either the old or the new dict)
                        self.performance_metrics = {
                            'loss': metrics.get('loss', 0.0),
                            'accuracy': metrics.get('accuracy', 0.0),
                            'samples': metrics.get('samples', 0),
                            'last_training': datetime.now().isoformat()
                        }

                        logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")

                    # Update last training time
                    self.last_training_time = current_time

                # wait() instead of sleep() so a stop request takes effect immediately
                stop_event.wait(1)

            except Exception as e:
                logger.exception(f"Error in CNN training loop: {e}")
                stop_event.wait(5)  # Back off longer on error

    def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
        """
        Process data for model inference and training

        Args:
            symbol: Trading symbol
            base_data: Standardized input data

        Returns:
            Optional[ModelOutput]: Model output, or None if the adapter is missing,
            produced no prediction, or processing failed
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return None

            # Make prediction
            model_output = self.cnn_adapter.predict(base_data)
            if model_output is None:
                # Don't cache a missing prediction: the metrics/visualization getters
                # dereference every stored prediction.
                logger.debug(f"CNN adapter returned no prediction for {symbol}")
                return None

            # Store prediction
            self.last_predictions[symbol] = model_output

            # Store model output in data provider
            if self.data_provider is not None:
                self.data_provider.store_model_output(model_output)

            return model_output

        except Exception as e:
            logger.exception(f"Error processing data for CNN model: {e}")
            return None

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample

        The sample is forwarded to the CNN adapter, and a (symbol, action, reward)
        summary is kept locally (bounded by ``max_training_samples``) for the dashboard.

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return

            # Add training sample to CNN adapter
            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            # Add to local training samples
            self.training_samples.append((base_data.symbol, actual_action, reward))

            # Keep only the newest samples. Rebinding (rather than deleting in place)
            # leaves any list another caller is iterating untouched.
            if len(self.training_samples) > self.max_training_samples:
                self.training_samples = self.training_samples[-self.max_training_samples:]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.exception(f"Error adding training sample: {e}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics

        Returns:
            Dict[str, Any]: The latest training metrics plus the local sample count,
            the model name, and ``<symbol>_last_action`` /
            ``<symbol>_last_confidence`` for every symbol with a cached prediction
        """
        metrics = self.performance_metrics.copy()

        # Add additional metrics
        metrics['training_samples'] = len(self.training_samples)
        metrics['model_name'] = self.model_name

        # Iterate a snapshot in case process_data() adds a symbol concurrently.
        for symbol, prediction in self.last_predictions.copy().items():
            metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
            metrics[f'{symbol}_last_confidence'] = prediction.confidence

        return metrics

    def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
        """
        Get visualization data for the dashboard

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, Any]: Model name, symbol, timestamp, performance metrics, the
            last prediction for ``symbol`` (if any) and a summary of its samples
        """
        data = {
            'model_name': self.model_name,
            'symbol': symbol,
            'timestamp': datetime.now().isoformat(),
            'performance_metrics': self.get_performance_metrics()
        }

        # Add last prediction
        prediction = self.last_predictions.get(symbol)
        if prediction is not None:
            data['last_prediction'] = {
                'action': prediction.predictions.get('action', 'UNKNOWN'),
                'confidence': prediction.confidence,
                'timestamp': prediction.timestamp.isoformat(),
                'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                'hold_probability': prediction.predictions.get('hold_probability', 0.0)
            }

        # Summarize this symbol's samples in a single pass
        symbol_samples = [s for s in self.training_samples if s[0] == symbol]
        action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
        for _, action, _ in symbol_samples:
            if action in action_counts:
                action_counts[action] += 1

        data['training_samples'] = {
            'total': len(symbol_samples),
            'buy': action_counts['BUY'],
            'sell': action_counts['SELL'],
            'hold': action_counts['HOLD'],
            'avg_reward': sum(reward for _, _, reward in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
        }

        return data
|
|
|
|
# Global CNN dashboard integration instance (created lazily on first request)
_cnn_dashboard_integration = None


def get_cnn_dashboard_integration(data_provider=None) -> "CNNDashboardIntegration":
    """
    Get the global CNN dashboard integration instance

    The instance is created on the first call. If it already exists without a
    data provider and a provider is passed now, that provider is attached, so
    an early provider-less call does not permanently disable storing model outputs.
    An existing provider is never replaced.

    Args:
        data_provider: Data provider instance

    Returns:
        CNNDashboardIntegration: Global CNN dashboard integration instance
    """
    global _cnn_dashboard_integration

    if _cnn_dashboard_integration is None:
        _cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)
    elif data_provider is not None and _cnn_dashboard_integration.data_provider is None:
        _cnn_dashboard_integration.data_provider = data_provider

    return _cnn_dashboard_integration