diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index 521c0a4..7feeb6a 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -1164,35 +1164,23 @@ class DQNAgent:
             # Check if state is a dict or complex object
             if isinstance(state, dict):
                 logger.error(f"State is a dict: {state}")
+
+                # Handle empty dictionary case
+                if not state:
+                    logger.error("No numerical values found in state dict, using default state")
+                    expected_size = getattr(self, 'state_size', 403)
+                    if isinstance(expected_size, tuple):
+                        expected_size = np.prod(expected_size)
+                    return np.zeros(int(expected_size), dtype=np.float32)
+
                 # Extract numerical values from dict if possible
                 if 'features' in state:
                     state = state['features']
                 elif 'state' in state:
                     state = state['state']
                 else:
-                    # Try to extract all numerical values
-                    numerical_values = []
-                    for key, value in state.items():
-                        if isinstance(value, (int, float)):
-                            numerical_values.append(float(value))
-                        elif isinstance(value, (list, np.ndarray)):
-                            try:
-                                # Handle nested structures safely
-                                flattened = np.array(value).flatten()
-                                for x in flattened:
-                                    if isinstance(x, (int, float)):
-                                        numerical_values.append(float(x))
-                                    elif hasattr(x, 'item'):  # numpy scalar
-                                        numerical_values.append(float(x.item()))
-                            except (ValueError, TypeError):
-                                continue
-                        elif isinstance(value, dict):
-                            # Recursively extract from nested dicts
-                            try:
-                                nested_values = self._extract_numeric_from_dict(value)
-                                numerical_values.extend(nested_values)
-                            except Exception:
-                                continue
+                    # Try to extract all numerical values using the helper method
+                    numerical_values = self._extract_numeric_from_dict(state)
                     if numerical_values:
                         state = np.array(numerical_values, dtype=np.float32)
                     else:
@@ -1254,6 +1242,31 @@ class DQNAgent:
                 expected_size = np.prod(expected_size)
             return np.zeros(int(expected_size), dtype=np.float32)
 
+    def _extract_numeric_from_dict(self, data_dict):
+        """Recursively extract numerical values from nested dictionaries"""
+        numerical_values = []
+        try:
+            for key, value in data_dict.items():
+                if isinstance(value, (int, float)):
+                    numerical_values.append(float(value))
+                elif isinstance(value, (list, np.ndarray)):
+                    try:
+                        flattened = np.array(value).flatten()
+                        for x in flattened:
+                            if isinstance(x, (int, float)):
+                                numerical_values.append(float(x))
+                            elif hasattr(x, 'item'):  # numpy scalar
+                                numerical_values.append(float(x.item()))
+                    except (ValueError, TypeError):
+                        continue
+                elif isinstance(value, dict):
+                    # Recursively extract from nested dicts
+                    nested_values = self._extract_numeric_from_dict(value)
+                    numerical_values.extend(nested_values)
+        except Exception as e:
+            logger.debug(f"Error extracting numeric values from dict: {e}")
+        return numerical_values
+
     def _replay_standard(self, states, actions, rewards, next_states, dones):
         """Standard training step without mixed precision"""
         try:
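Reviewer note: the change above moves the inline extraction loop into a shared `_extract_numeric_from_dict` helper. A minimal standalone sketch of the same recursion, useful for sanity-checking the behavior (illustrative names, numpy as the only dependency):

```python
# Standalone sketch of the recursive numeric extraction the new helper
# centralizes; `extract_numeric` is an illustrative name, not the repo's.
import numpy as np

def extract_numeric(data: dict) -> list:
    """Flatten ints, floats, array-likes, and nested dicts into a float list."""
    values = []
    for value in data.values():
        if isinstance(value, (int, float)):
            values.append(float(value))
        elif isinstance(value, (list, np.ndarray)):
            try:
                values.extend(float(x) for x in np.asarray(value, dtype=np.float64).flatten())
            except (ValueError, TypeError):
                continue  # skip array contents that cannot be coerced to float
        elif isinstance(value, dict):
            values.extend(extract_numeric(value))  # recurse into nested dicts
    return values

# Mixed scalars, arrays, and nesting collapse to one flat list:
print(extract_numeric({'a': 1, 'b': [2.0, 3], 'c': {'d': np.array([4, 5])}}))
# [1.0, 2.0, 3.0, 4.0, 5.0]
```

One caveat worth flagging in review: dict insertion order determines feature order, so callers must not assume a stable feature layout if the state dict's schema changes.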
diff --git a/core/data_provider.py b/core/data_provider.py
index 1ebbd02..1f417a9 100644
--- a/core/data_provider.py
+++ b/core/data_provider.py
@@ -83,6 +83,13 @@ class PivotBounds:
         distances = [abs(current_price - r) for r in self.pivot_resistance_levels]
         return min(distances) / self.get_price_range()
 
+@dataclass
+class SimplePivotLevel:
+    """Simple pivot level structure for fallback pivot detection"""
+    swing_points: List[Any] = field(default_factory=list)
+    support_levels: List[float] = field(default_factory=list)
+    resistance_levels: List[float] = field(default_factory=list)
+
 @dataclass
 class MarketTick:
     """Standardized market tick data structure"""
@@ -127,6 +134,10 @@ class DataProvider:
         self.real_time_data = {}  # {symbol: {timeframe: deque}}
         self.current_prices = {}  # {symbol: float}
 
+        # Live price cache for low-latency price updates
+        self.live_price_cache: Dict[str, Tuple[float, datetime]] = {}
+        self.live_price_cache_ttl = timedelta(milliseconds=500)
+
         # Initialize cached data structure
         for symbol in self.symbols:
             self.cached_data[symbol] = {}
@@ -1839,14 +1850,14 @@ class DataProvider:
                     low_pivots = monthly_data[lows == rolling_min]['low'].tolist()
                     pivot_lows.extend(low_pivots)
 
-            # Create mock level structure
-            mock_level = type('MockLevel', (), {
-                'swing_points': [],
-                'support_levels': list(set(pivot_lows)),
-                'resistance_levels': list(set(pivot_highs))
-            })()
+            # Create proper pivot level structure
+            pivot_level = SimplePivotLevel(
+                swing_points=[],
+                support_levels=list(set(pivot_lows)),
+                resistance_levels=list(set(pivot_highs))
+            )
 
-            return {'level_0': mock_level}
+            return {'level_0': pivot_level}
 
         except Exception as e:
             logger.error(f"Error in simple pivot detection: {e}")
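Reviewer note: the diff adds the `live_price_cache` fields and a 500 ms TTL in `DataProvider.__init__`, but the read path is outside these hunks. A sketch of how such a TTL cache is typically consumed — all names here are hypothetical, this is an assumption about intended use:

```python
# Hypothetical read/write path for a TTL-bounded live price cache.
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple

class PriceCache:
    def __init__(self, ttl_ms: int = 500):
        self.cache: Dict[str, Tuple[float, datetime]] = {}
        self.ttl = timedelta(milliseconds=ttl_ms)

    def get(self, symbol: str) -> Optional[float]:
        """Return a cached price only while it is younger than the TTL."""
        entry = self.cache.get(symbol)
        if entry is None:
            return None
        price, stamp = entry
        if datetime.now() - stamp > self.ttl:
            return None  # stale entry: caller should fetch a fresh price
        return price

    def put(self, symbol: str, price: float) -> None:
        self.cache[symbol] = (price, datetime.now())

cache = PriceCache()
cache.put('ETH/USDT', 3500.25)
print(cache.get('ETH/USDT'))  # 3500.25 until ~500 ms elapse, then None
```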
{self.model_registry.models}") @@ -1691,6 +1702,15 @@ class TradingOrchestrator: try: logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})") + # Validate model_input before storing + if model_input is None: + logger.warning(f"Skipping inference storage for {model_name}: model_input is None") + return + + if isinstance(model_input, dict) and not model_input: + logger.warning(f"Skipping inference storage for {model_name}: model_input is empty dict") + return + # Extract symbol from prediction if not provided if symbol is None: symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available @@ -2569,6 +2589,25 @@ class TradingOrchestrator: # Method 3: Dictionary with feature data if isinstance(model_input, dict): + # Check if dictionary is empty + if not model_input: + logger.warning(f"Empty dictionary passed as model_input for {model_name}, using fallback") + # Try to use data provider to build state as fallback + if hasattr(self, 'data_provider'): + try: + base_data = self.data_provider.build_base_data_input('ETH/USDT') + if base_data and hasattr(base_data, 'get_feature_vector'): + state = base_data.get_feature_vector() + if isinstance(state, np.ndarray): + logger.debug(f"Used data provider fallback for empty dict in {model_name}") + return state + except Exception as e: + logger.debug(f"Data provider fallback failed for empty dict in {model_name}: {e}") + + # Final fallback: return default state + logger.warning(f"Using default state for empty dict in {model_name}") + return np.zeros(403, dtype=np.float32) # Default state size + # Try to extract features from dictionary if 'features' in model_input: features = model_input['features'] @@ -2589,6 +2628,8 @@ class TradingOrchestrator: if feature_list: return np.array(feature_list, dtype=np.float32) + else: + logger.warning(f"No numerical features found in dictionary for {model_name}, using fallback") # Method 4: List or tuple if isinstance(model_input, (list, tuple)): diff --git a/main.py b/main.py index 71e83e1..bde3748 100644 --- a/main.py +++ b/main.py @@ -65,16 +65,27 @@ async def run_web_dashboard(): except Exception as e: logger.warning(f"[WARNING] Real-time streaming failed: {e}") - # Verify data connection + # Verify data connection with retry mechanism logger.info("[DATA] Verifying live data connection...") symbol = config.get('symbols', ['ETH/USDT'])[0] - test_df = data_provider.get_historical_data(symbol, '1m', limit=10) - if test_df is not None and len(test_df) > 0: - logger.info("[SUCCESS] Data connection verified") - logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation") - else: - logger.error("[ERROR] Data connection failed - no live data available") - return + + # Wait for data provider to initialize and fetch initial data + max_retries = 10 + retry_delay = 2 + + for attempt in range(max_retries): + test_df = data_provider.get_historical_data(symbol, '1m', limit=10) + if test_df is not None and len(test_df) > 0: + logger.info("[SUCCESS] Data connection verified") + logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation") + break + else: + if attempt < max_retries - 1: + logger.info(f"[DATA] Waiting for data provider to initialize... 
diff --git a/main.py b/main.py
index 71e83e1..bde3748 100644
--- a/main.py
+++ b/main.py
@@ -65,16 +65,27 @@ async def run_web_dashboard():
         except Exception as e:
             logger.warning(f"[WARNING] Real-time streaming failed: {e}")
 
-        # Verify data connection
+        # Verify data connection with retry mechanism
         logger.info("[DATA] Verifying live data connection...")
         symbol = config.get('symbols', ['ETH/USDT'])[0]
-        test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
-        if test_df is not None and len(test_df) > 0:
-            logger.info("[SUCCESS] Data connection verified")
-            logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
-        else:
-            logger.error("[ERROR] Data connection failed - no live data available")
-            return
+
+        # Wait for data provider to initialize and fetch initial data
+        max_retries = 10
+        retry_delay = 2
+
+        for attempt in range(max_retries):
+            test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
+            if test_df is not None and len(test_df) > 0:
+                logger.info("[SUCCESS] Data connection verified")
+                logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
+                break
+            else:
+                if attempt < max_retries - 1:
+                    logger.info(f"[DATA] Waiting for data provider to initialize... (attempt {attempt + 1}/{max_retries})")
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.warning("[WARNING] Data connection verification failed, but continuing with system startup")
+                    logger.warning("The system will attempt to fetch data as needed during operation")
 
         # Load model registry for integrated pipeline
         try:
@@ -122,6 +133,7 @@ async def run_web_dashboard():
         logger.info("Starting training loop...")
 
         # Start the training loop
+        logger.info("About to start training loop...")
         await start_training_loop(orchestrator, trading_executor)
 
     except Exception as e:
@@ -207,6 +219,8 @@ async def start_training_loop(orchestrator, trading_executor):
     logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
     logger.info("=" * 70)
 
+    logger.info("Training loop function entered successfully")
+
     # Initialize checkpoint management for training loop
     checkpoint_manager = get_checkpoint_manager()
     training_integration = get_training_integration()
@@ -222,8 +236,10 @@ async def start_training_loop(orchestrator, trading_executor):
 
     try:
        # Start real-time processing (Basic orchestrator doesn't have this method)
+        logger.info("Checking for real-time processing capabilities...")
        try:
            if hasattr(orchestrator, 'start_realtime_processing'):
+                logger.info("Starting real-time processing...")
                await orchestrator.start_realtime_processing()
                logger.info("Real-time processing started")
            else:
@@ -231,6 +247,8 @@
        except Exception as e:
            logger.warning(f"Real-time processing not available: {e}")
 
+    logger.info("About to enter main training loop...")
+
    # Main training loop
    iteration = 0
    while True:
diff --git a/utils/checkpoint_manager.py b/utils/checkpoint_manager.py
index ba989dc..844d1c5 100644
--- a/utils/checkpoint_manager.py
+++ b/utils/checkpoint_manager.py
@@ -491,4 +491,57 @@ class CheckpointManager:
 
         except Exception as e:
             logger.error(f"Error getting all checkpoints: {e}")
-            return []
\ No newline at end of file
+            return []
+
+    def get_checkpoint_stats(self) -> Dict[str, Any]:
+        """
+        Get statistics about all checkpoints
+
+        Returns:
+            Dict[str, Any]: Statistics about checkpoints
+        """
+        try:
+            stats = {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
+
+            # Iterate through all model directories
+            for model_dir in os.listdir(self.checkpoint_dir):
+                model_path = os.path.join(self.checkpoint_dir, model_dir)
+                if not os.path.isdir(model_path):
+                    continue
+
+                # Count checkpoints for this model
+                checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
+                model_checkpoints = len(checkpoint_files)
+
+                # Calculate total size for this model
+                model_size_mb = 0.0
+                for checkpoint_file in checkpoint_files:
+                    try:
+                        size_bytes = os.path.getsize(checkpoint_file)
+                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
+                    except OSError:
+                        pass
+
+                stats['models'][model_dir] = {
+                    'checkpoints': model_checkpoints,
+                    'size_mb': round(model_size_mb, 2)
+                }
+
+                stats['total_checkpoints'] += model_checkpoints
+                stats['total_size_mb'] += model_size_mb
+
+            stats['total_size_mb'] = round(stats['total_size_mb'], 2)
+
+            return stats
+
+        except Exception as e:
+            logger.error(f"Error getting checkpoint stats: {e}")
+            return {
+                'total_checkpoints': 0,
+                'total_size_mb': 0.0,
+                'models': {}
+            }
\ No newline at end of file
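Reviewer note: a sketch of how a caller might consume the new `get_checkpoint_stats()` payload. `get_checkpoint_manager()` is the accessor already used in main.py; its import path here is an assumption:

```python
# Possible consumer of the stats dict returned by get_checkpoint_stats();
# the import path is assumed, only the payload shape is taken from the diff.
from utils.checkpoint_manager import get_checkpoint_manager

stats = get_checkpoint_manager().get_checkpoint_stats()
print(f"{stats['total_checkpoints']} checkpoints, {stats['total_size_mb']:.2f} MB on disk")
for model_name, info in sorted(stats['models'].items()):
    print(f"  {model_name}: {info['checkpoints']} files, {info['size_mb']:.2f} MB")
```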
diff --git a/utils/training_integration.py b/utils/training_integration.py
index 402b066..3e7ee89 100644
--- a/utils/training_integration.py
+++ b/utils/training_integration.py
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
 
 class TrainingIntegration:
     def __init__(self, enable_wandb: bool = True):
+        self.enable_wandb = enable_wandb
         self.checkpoint_manager = get_checkpoint_manager()
 
@@ -55,9 +56,13 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
 
+            # Save the model first to get the path
+            model_path = f"models/{model_name}_temp.pt"
+            torch.save(cnn_model.state_dict(), model_path)
+
             metadata = self.checkpoint_manager.save_checkpoint(
-                model=cnn_model,
                 model_name=model_name,
+                model_path=model_path,
                 model_type='cnn',
                 performance_metrics=performance_metrics,
                 training_metadata=training_metadata
@@ -114,9 +119,13 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
 
+            # Save the model first to get the path
+            model_path = f"models/{model_name}_temp.pt"
+            torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
+
             metadata = self.checkpoint_manager.save_checkpoint(
-                model=rl_agent,
                 model_name=model_name,
+                model_path=model_path,
                 model_type='rl',
                 performance_metrics=performance_metrics,
                 training_metadata=training_metadata
diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py
index 181ea7c..5a223c1 100644
--- a/web/clean_dashboard.py
+++ b/web/clean_dashboard.py
@@ -6056,6 +6056,7 @@ class CleanTradingDashboard:
 
             # Fallback: create BaseDataInput from available data
             from core.data_models import BaseDataInput, OHLCVBar, COBData
+            import random
 
             # Get OHLCV data for different timeframes - ensure we have enough data
             ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
@@ -6073,7 +6074,6 @@ class CleanTradingDashboard:
                 if len(bars) > 0:
                     last_bar = bars[-1]
                     # Add small random variation to prevent identical data
-                    import random
                     for i in range(target_count - len(bars)):
                         # Create slight variations of the last bar
                         variation = random.uniform(-0.001, 0.001)  # 0.1% variation
@@ -6090,7 +6090,6 @@ class CleanTradingDashboard:
                         bars.append(new_bar)
                 else:
                     # Create realistic dummy bars with variation
-                    from core.data_models import OHLCVBar
                     base_price = 3500.0
                     for i in range(target_count):
                         # Add realistic price movement
@@ -8725,6 +8724,14 @@ def signal_handler(sig, frame):
     self.shutdown()  # Assuming a shutdown method exists or add one
     sys.exit(0)
 
-signal.signal(signal.SIGTERM, signal_handler)
-signal.signal(signal.SIGINT, signal_handler)
+# Only set signal handlers if we're in the main thread
+try:
+    import threading
+    if threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+    else:
+        print("Warning: Signal handlers can only be set in main thread, skipping...")
+except Exception as e:
+    print(f"Warning: Could not set signal handlers: {e}")
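Reviewer note: the final hunk guards against `signal.signal()` raising `ValueError` when invoked off the main thread (e.g. when the dashboard is launched from a worker thread). A standalone illustration of the same guard, with `install_handlers` as an illustrative name:

```python
# Minimal sketch of the main-thread guard now wrapping the dashboard's
# signal registration; signal.signal() raises ValueError off the main thread.
import signal
import threading

def install_handlers(handler) -> bool:
    """Register SIGTERM/SIGINT handlers only when on the main thread."""
    if threading.current_thread() is not threading.main_thread():
        return False  # caller logs a warning and continues without handlers
    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)
    return True

if __name__ == '__main__':
    installed = install_handlers(lambda sig, frame: print(f"signal {sig}"))
    print(f"handlers installed: {installed}")  # True when run as a script
```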