From 94ee7389c4f0038c981820e1a935145e68b7c91a Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 23 Jul 2025 22:39:00 +0300 Subject: [PATCH] CNN training first working --- core/enhanced_cnn_adapter.py | 62 +++++- web/clean_dashboard.py | 402 ++++++++++++++++++++++++++++++++++- 2 files changed, 456 insertions(+), 8 deletions(-) diff --git a/core/enhanced_cnn_adapter.py b/core/enhanced_cnn_adapter.py index 586ab38..8a66b7a 100644 --- a/core/enhanced_cnn_adapter.py +++ b/core/enhanced_cnn_adapter.py @@ -51,10 +51,68 @@ class EnhancedCNNAdapter: # Create checkpoint directory if it doesn't exist os.makedirs(checkpoint_dir, exist_ok=True) - # Initialize model + # Initialize the model self._initialize_model() - logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}") + # Load checkpoint if available + if model_path and os.path.exists(model_path): + self._load_checkpoint(model_path) + else: + self._load_best_checkpoint() + + logger.info(f"EnhancedCNNAdapter initialized on {self.device}") + + def _load_checkpoint(self, checkpoint_path: str) -> bool: + """Load model from checkpoint path""" + try: + if self.model and os.path.exists(checkpoint_path): + success = self.model.load(checkpoint_path) + if success: + logger.info(f"Loaded model from {checkpoint_path}") + return True + else: + logger.warning(f"Failed to load model from {checkpoint_path}") + return False + else: + logger.warning(f"Checkpoint path does not exist: {checkpoint_path}") + return False + except Exception as e: + logger.error(f"Error loading checkpoint: {e}") + return False + + def _load_best_checkpoint(self) -> bool: + """Load the best available checkpoint""" + try: + return self.load_best_checkpoint() + except Exception as e: + logger.error(f"Error loading best checkpoint: {e}") + return False + + + + def _create_default_output(self, symbol: str) -> ModelOutput: + """Create default output when prediction fails""" + return create_model_output( + model_type='cnn', + model_name=self.model_name, + symbol=symbol, + action='HOLD', + confidence=0.0, + metadata={'error': 'Prediction failed, using default output'} + ) + + def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]: + """Process hidden states for cross-model feeding""" + processed_states = {} + + for key, value in hidden_states.items(): + if isinstance(value, torch.Tensor): + # Convert tensor to numpy array + processed_states[key] = value.cpu().numpy().tolist() + else: + processed_states[key] = value + + return processed_states def _initialize_model(self): """Initialize the EnhancedCNN model""" diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index 89c7c61..2dc74fd 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -259,6 +259,10 @@ class CleanTradingDashboard: self.data_provider.start_cob_collection() logger.info("Started COB collection in data provider") + # Start CNN real-time prediction loop + self._start_cnn_prediction_loop() + logger.info("Started CNN real-time prediction loop") + # Then subscribe to updates self.data_provider.subscribe_to_cob(self._on_cob_data_update) logger.info("Subscribed to COB data updates from data provider") @@ -2718,6 +2722,82 @@ class CleanTradingDashboard: logger.debug(f"Error getting enhanced training stats: {e}") return {} + def _update_cnn_model_panel(self) -> Dict[str, Any]: + """Update CNN model panel with real-time data and performance metrics""" + try: + if not self.cnn_adapter: + return { + 'status': 'NOT_AVAILABLE', + 'parameters': '0M', + 'current_loss': 0.0, + 'accuracy': 0.0, + 'confidence': 0.0, + 'last_prediction': 'N/A', + 'training_samples': 0, + 'inference_rate': '0.00/s' + } + + # Get CNN prediction for ETH/USDT + prediction = self._get_cnn_prediction('ETH/USDT') + + # Get model performance metrics + model_info = self.cnn_adapter.get_model_info() if hasattr(self.cnn_adapter, 'get_model_info') else {} + + # Calculate inference rate + inference_times = getattr(self.cnn_adapter, 'inference_times', []) + if len(inference_times) > 0: + avg_inference_time = sum(inference_times[-10:]) / min(len(inference_times), 10) + inference_rate = f"{1.0/avg_inference_time:.2f}/s" if avg_inference_time > 0 else "0.00/s" + else: + inference_rate = "0.00/s" + + # Get training data count + training_samples = len(getattr(self.cnn_adapter, 'training_data', [])) + + # Format last prediction + if prediction: + last_prediction = f"{prediction['action']} ({prediction['confidence']:.1%})" + current_confidence = prediction['confidence'] + else: + last_prediction = "No prediction" + current_confidence = 0.0 + + # Get model status + if hasattr(self.cnn_adapter, 'model') and self.cnn_adapter.model: + if training_samples > 100: + status = 'TRAINED' + elif training_samples > 0: + status = 'TRAINING' + else: + status = 'FRESH' + else: + status = 'NOT_LOADED' + + return { + 'status': status, + 'parameters': model_info.get('parameters', '50.0M'), + 'current_loss': model_info.get('current_loss', 0.0), + 'accuracy': model_info.get('accuracy', 0.0), + 'confidence': current_confidence, + 'last_prediction': last_prediction, + 'training_samples': training_samples, + 'inference_rate': inference_rate, + 'last_update': datetime.now().strftime('%H:%M:%S') + } + + except Exception as e: + logger.error(f"Error updating CNN model panel: {e}") + return { + 'status': 'ERROR', + 'parameters': '0M', + 'current_loss': 0.0, + 'accuracy': 0.0, + 'confidence': 0.0, + 'last_prediction': f'Error: {str(e)}', + 'training_samples': 0, + 'inference_rate': '0.00/s' + } + def _get_training_metrics(self) -> Dict: """Get training metrics from unified orchestrator - using orchestrator as SSOT""" try: @@ -2751,6 +2831,19 @@ class CleanTradingDashboard: latest_predictions = self._get_latest_model_predictions() cnn_prediction = self._get_cnn_pivot_prediction() + # Get enhanced CNN model panel data + cnn_panel_data = self._update_cnn_model_panel() + + # Update CNN model in loaded_models with real-time data + if cnn_panel_data: + model_states['cnn'].update({ + 'status': cnn_panel_data.get('status', 'FRESH'), + 'confidence': cnn_panel_data.get('confidence', 0.0), + 'last_prediction': cnn_panel_data.get('last_prediction', 'No prediction'), + 'training_samples': cnn_panel_data.get('training_samples', 0), + 'inference_rate': cnn_panel_data.get('inference_rate', '0.00/s') + }) + # Get enhanced training statistics if available enhanced_training_stats = self._get_enhanced_training_stats() @@ -5534,14 +5627,311 @@ class CleanTradingDashboard: self.training_system = None def _initialize_standardized_cnn(self): - """Initialize StandardizedCNN model for the dashboard""" + """Initialize Enhanced CNN model with standardized input format for the dashboard""" try: - from NN.models.standardized_cnn import StandardizedCNN - self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn") - logger.info("StandardizedCNN model initialized for dashboard") + from core.enhanced_cnn_adapter import EnhancedCNNAdapter + + # Initialize the enhanced CNN adapter + self.cnn_adapter = EnhancedCNNAdapter( + checkpoint_dir="models/enhanced_cnn" + ) + + # For backward compatibility + self.standardized_cnn = self.cnn_adapter + + logger.info("Enhanced CNN adapter initialized for dashboard with standardized input format") + except Exception as e: - logger.warning(f"StandardizedCNN initialization failed: {e}") - self.standardized_cnn = None + logger.warning(f"Enhanced CNN adapter initialization failed: {e}") + + # Fallback to original StandardizedCNN + try: + from NN.models.standardized_cnn import StandardizedCNN + self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn") + self.cnn_adapter = None + logger.info("Fallback to StandardizedCNN model initialized for dashboard") + except Exception as e2: + logger.warning(f"StandardizedCNN fallback initialization failed: {e2}") + self.standardized_cnn = None + self.cnn_adapter = None + + def _get_cnn_prediction(self, symbol: str = 'ETH/USDT') -> Optional[Dict[str, Any]]: + """Get CNN prediction using standardized input format""" + try: + if not self.cnn_adapter: + return None + + # Get standardized input data from data provider + base_data_input = self._get_base_data_input(symbol) + if not base_data_input: + logger.debug(f"No base data input available for {symbol}") + return None + + # Make prediction using CNN adapter + model_output = self.cnn_adapter.predict(base_data_input) + + # Convert to dictionary for dashboard use + prediction = { + 'action': model_output.predictions.get('action', 'HOLD'), + 'confidence': model_output.confidence, + 'buy_probability': model_output.predictions.get('buy_probability', 0.0), + 'sell_probability': model_output.predictions.get('sell_probability', 0.0), + 'hold_probability': model_output.predictions.get('hold_probability', 0.0), + 'timestamp': model_output.timestamp, + 'hidden_states': model_output.hidden_states, + 'metadata': model_output.metadata + } + + logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})") + return prediction + + except Exception as e: + logger.error(f"Error getting CNN prediction: {e}") + return None + + def _get_base_data_input(self, symbol: str = 'ETH/USDT') -> Optional['BaseDataInput']: + """Get standardized BaseDataInput from data provider""" + try: + # Check if data provider supports standardized input + if hasattr(self.data_provider, 'get_base_data_input'): + return self.data_provider.get_base_data_input(symbol) + + # Fallback: create BaseDataInput from available data + from core.data_models import BaseDataInput, OHLCVBar, COBData + + # Get OHLCV data for different timeframes + ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300) + ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300) + ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300) + ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300) + + # Get BTC reference data + btc_ohlcv_1s = self._get_ohlcv_bars('BTC/USDT', '1s', 300) + + # Get COB data if available + cob_data = self._get_cob_data(symbol) + + # Create BaseDataInput + base_data_input = BaseDataInput( + symbol=symbol, + timestamp=datetime.now(), + ohlcv_1s=ohlcv_1s, + ohlcv_1m=ohlcv_1m, + ohlcv_1h=ohlcv_1h, + ohlcv_1d=ohlcv_1d, + btc_ohlcv_1s=btc_ohlcv_1s, + cob_data=cob_data, + technical_indicators=self._get_technical_indicators(symbol), + pivot_points=self._get_pivot_points(symbol), + last_predictions={} # TODO: Add cross-model predictions + ) + + return base_data_input + + except Exception as e: + logger.error(f"Error creating base data input: {e}") + return None + + def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List['OHLCVBar']: + """Get OHLCV bars from data provider""" + try: + from core.data_models import OHLCVBar + + # Get data from data provider + df = self.data_provider.get_candles(symbol, timeframe) + if df is None or len(df) == 0: + return [] + + # Convert to OHLCVBar objects + bars = [] + for idx, row in df.tail(count).iterrows(): + bar = OHLCVBar( + symbol=symbol, + timestamp=idx if isinstance(idx, datetime) else datetime.now(), + open=float(row['open']), + high=float(row['high']), + low=float(row['low']), + close=float(row['close']), + volume=float(row['volume']), + timeframe=timeframe, + indicators={} # TODO: Add technical indicators + ) + bars.append(bar) + + return bars + + except Exception as e: + logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}") + return [] + + def _get_cob_data(self, symbol: str) -> Optional['COBData']: + """Get COB data from latest cache""" + try: + if not hasattr(self, 'latest_cob_data') or symbol not in self.latest_cob_data: + return None + + from core.data_models import COBData + + cob_raw = self.latest_cob_data[symbol] + if not isinstance(cob_raw, dict) or 'stats' not in cob_raw: + return None + + stats = cob_raw['stats'] + current_price = stats.get('mid_price', 0.0) + + # Create price buckets (simplified for now) + bucket_size = 1.0 if 'ETH' in symbol else 10.0 + price_buckets = {} + + # Create ±20 buckets around current price + for i in range(-20, 21): + price = current_price + (i * bucket_size) + price_buckets[price] = { + 'bid_volume': 0.0, + 'ask_volume': 0.0, + 'total_volume': 0.0, + 'imbalance': stats.get('imbalance', 0.0) + } + + cob_data = COBData( + symbol=symbol, + timestamp=cob_raw.get('timestamp', datetime.now()), + current_price=current_price, + bucket_size=bucket_size, + price_buckets=price_buckets, + bid_ask_imbalance={current_price: stats.get('imbalance', 0.0)}, + volume_weighted_prices={current_price: current_price}, + order_flow_metrics=stats, + ma_1s_imbalance={current_price: stats.get('imbalance', 0.0)}, + ma_5s_imbalance={current_price: stats.get('imbalance_5s', 0.0)}, + ma_15s_imbalance={current_price: stats.get('imbalance_15s', 0.0)}, + ma_60s_imbalance={current_price: stats.get('imbalance_60s', 0.0)} + ) + + return cob_data + + except Exception as e: + logger.error(f"Error creating COB data for {symbol}: {e}") + return None + + def _get_technical_indicators(self, symbol: str) -> Dict[str, float]: + """Get technical indicators for symbol""" + try: + # TODO: Implement technical indicators calculation + return {} + except Exception as e: + logger.error(f"Error getting technical indicators for {symbol}: {e}") + return {} + + def _get_pivot_points(self, symbol: str) -> List['PivotPoint']: + """Get pivot points for symbol""" + try: + # TODO: Implement pivot points calculation + return [] + except Exception as e: + logger.error(f"Error getting pivot points for {symbol}: {e}") + return [] + + def _start_cnn_prediction_loop(self): + """Start CNN real-time prediction loop""" + try: + if not self.cnn_adapter: + logger.warning("CNN adapter not available, skipping prediction loop") + return + + def cnn_prediction_worker(): + """Worker thread for CNN predictions""" + logger.info("CNN prediction worker started") + + while True: + try: + # Make predictions for primary symbols + for symbol in ['ETH/USDT', 'BTC/USDT']: + prediction = self._get_cnn_prediction(symbol) + + if prediction: + # Store prediction for dashboard display + if not hasattr(self, 'cnn_predictions'): + self.cnn_predictions = {} + + self.cnn_predictions[symbol] = prediction + + # Add to training data if confidence is high enough + if prediction['confidence'] > 0.7: + self._add_cnn_training_sample(symbol, prediction) + + logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})") + + # Sleep for 1 second (1Hz prediction rate) + time.sleep(1.0) + + except Exception as e: + logger.error(f"Error in CNN prediction worker: {e}") + time.sleep(5.0) # Wait longer on error + + # Start the worker thread + import threading + import time + prediction_thread = threading.Thread(target=cnn_prediction_worker, daemon=True) + prediction_thread.start() + + logger.info("CNN real-time prediction loop started") + + except Exception as e: + logger.error(f"Error starting CNN prediction loop: {e}") + + def _add_cnn_training_sample(self, symbol: str, prediction: Dict[str, Any]): + """Add CNN training sample based on prediction outcome""" + try: + if not self.cnn_adapter or not hasattr(self.cnn_adapter, 'add_training_sample'): + return + + # Get current price for reward calculation + current_price = self._get_current_price(symbol) + if not current_price: + return + + # Calculate reward based on prediction accuracy (simplified) + # In a real implementation, this would be based on actual market movement + action = prediction['action'] + confidence = prediction['confidence'] + + # Simple reward: higher confidence predictions get higher rewards + base_reward = confidence * 0.1 + + # Add some market context (price movement direction) + price_history = self._get_recent_price_history(symbol, 10) + if len(price_history) >= 2: + price_change = (price_history[-1] - price_history[-2]) / price_history[-2] + + # Reward if prediction aligns with price movement + if (action == 'BUY' and price_change > 0) or (action == 'SELL' and price_change < 0): + reward = base_reward * 1.5 # Bonus for correct direction + else: + reward = base_reward * 0.5 # Penalty for wrong direction + else: + reward = base_reward + + # Add training sample + self.cnn_adapter.add_training_sample(symbol, action, reward) + + logger.debug(f"Added CNN training sample: {symbol} {action} (reward: {reward:.4f})") + + except Exception as e: + logger.error(f"Error adding CNN training sample: {e}") + + def _get_recent_price_history(self, symbol: str, count: int) -> List[float]: + """Get recent price history for reward calculation""" + try: + df = self.data_provider.get_candles(symbol, '1s') + if df is None or len(df) == 0: + return [] + + return df['close'].tail(count).tolist() + + except Exception as e: + logger.error(f"Error getting price history for {symbol}: {e}") + return [] def _initialize_enhanced_position_sync(self): """Initialize enhanced position synchronization system"""