diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 993a156..6ef63a9 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -113,6 +113,15 @@ class DQNAgent: # Initialize avg_reward for dashboard compatibility self.avg_reward = 0.0 # Average reward tracking for dashboard + # Market regime adaptation weights + self.market_regime_weights = { + 'trending': 1.0, + 'sideways': 0.8, + 'volatile': 1.2, + 'bullish': 1.1, + 'bearish': 1.1 + } + # Load best checkpoint if available if self.enable_checkpoints: self.load_best_checkpoint() @@ -490,7 +499,17 @@ class DQNAgent: state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) q_values = self.policy_net(state_tensor) - # Ensure q_values has correct shape for softmax + # Handle case where network might return a tuple instead of tensor + if isinstance(q_values, tuple): + # If it's a tuple, take the first element (usually the main output) + q_values = q_values[0] + + # Ensure q_values is a tensor and has correct shape for softmax + if not hasattr(q_values, 'dim'): + logger.error(f"DQN: q_values is not a tensor: {type(q_values)}") + # Return default action with low confidence + return 1, 0.1 # Default to HOLD action + if q_values.dim() == 1: q_values = q_values.unsqueeze(0) diff --git a/NN/models/enhanced_cnn.py b/NN/models/enhanced_cnn.py index 735a50b..722aec8 100644 --- a/NN/models/enhanced_cnn.py +++ b/NN/models/enhanced_cnn.py @@ -117,52 +117,52 @@ class EnhancedCNN(nn.Module): # Ultra massive convolutional backbone with much deeper residual blocks self.conv_layers = nn.Sequential( # Initial ultra large conv block - nn.Conv1d(self.channels, 512, kernel_size=7, padding=3), # Ultra wide initial layer - nn.BatchNorm1d(512), + nn.Conv1d(self.channels, 1024, kernel_size=7, padding=3), # Ultra wide initial layer (increased from 512) + nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.1), - # First residual stage - 512 channels - ResidualBlock(512, 768), - ResidualBlock(768, 768), - ResidualBlock(768, 768), - ResidualBlock(768, 768), # Additional layer - nn.MaxPool1d(kernel_size=2, stride=2), - nn.Dropout(0.2), - - # Second residual stage - 768 to 1024 channels - ResidualBlock(768, 1024), - ResidualBlock(1024, 1024), - ResidualBlock(1024, 1024), - ResidualBlock(1024, 1024), # Additional layer - nn.MaxPool1d(kernel_size=2, stride=2), - nn.Dropout(0.25), - - # Third residual stage - 1024 to 1536 channels - ResidualBlock(1024, 1536), + # First residual stage - 1024 channels (increased from 512) + ResidualBlock(1024, 1536), # Increased from 768 ResidualBlock(1536, 1536), ResidualBlock(1536, 1536), ResidualBlock(1536, 1536), # Additional layer nn.MaxPool1d(kernel_size=2, stride=2), - nn.Dropout(0.3), + nn.Dropout(0.2), - # Fourth residual stage - 1536 to 2048 channels + # Second residual stage - 1536 to 2048 channels (increased from 768 to 1024) ResidualBlock(1536, 2048), ResidualBlock(2048, 2048), ResidualBlock(2048, 2048), ResidualBlock(2048, 2048), # Additional layer nn.MaxPool1d(kernel_size=2, stride=2), - nn.Dropout(0.3), + nn.Dropout(0.25), - # Fifth residual stage - ULTRA MASSIVE 2048 to 3072 channels + # Third residual stage - 2048 to 3072 channels (increased from 1024 to 1536) ResidualBlock(2048, 3072), ResidualBlock(3072, 3072), ResidualBlock(3072, 3072), - ResidualBlock(3072, 3072), + ResidualBlock(3072, 3072), # Additional layer + nn.MaxPool1d(kernel_size=2, stride=2), + nn.Dropout(0.3), + + # Fourth residual stage - 3072 to 4096 channels (increased from 1536 to 2048) + ResidualBlock(3072, 4096), + ResidualBlock(4096, 4096), + ResidualBlock(4096, 4096), + ResidualBlock(4096, 4096), # Additional layer + nn.MaxPool1d(kernel_size=2, stride=2), + nn.Dropout(0.3), + + # Fifth residual stage - ULTRA MASSIVE 4096 to 6144 channels (increased from 2048 to 3072) + ResidualBlock(4096, 6144), + ResidualBlock(6144, 6144), + ResidualBlock(6144, 6144), + ResidualBlock(6144, 6144), nn.AdaptiveAvgPool1d(1) # Global average pooling ) # Ultra massive feature dimension after conv layers - self.conv_features = 3072 + self.conv_features = 6144 # Increased from 3072 else: # For 1D vectors, use ultra massive dense preprocessing self.conv_layers = None @@ -171,36 +171,36 @@ class EnhancedCNN(nn.Module): # ULTRA MASSIVE fully connected feature extraction layers if self.conv_layers is None: # For 1D inputs - ultra massive feature extraction - self.fc1 = nn.Linear(self.feature_dim, 3072) - self.features_dim = 3072 + self.fc1 = nn.Linear(self.feature_dim, 6144) # Increased from 3072 + self.features_dim = 6144 # Increased from 3072 else: # For data processed by ultra massive conv layers - self.fc1 = nn.Linear(self.conv_features, 3072) - self.features_dim = 3072 + self.fc1 = nn.Linear(self.conv_features, 6144) # Increased from 3072 + self.features_dim = 6144 # Increased from 3072 # ULTRA MASSIVE common feature extraction with multiple deep layers self.fc_layers = nn.Sequential( self.fc1, nn.ReLU(), nn.Dropout(0.3), - nn.Linear(3072, 3072), # Keep ultra massive width + nn.Linear(6144, 6144), # Keep ultra massive width (increased from 3072) nn.ReLU(), nn.Dropout(0.3), - nn.Linear(3072, 2560), # Ultra wide hidden layer + nn.Linear(6144, 4096), # Ultra wide hidden layer (increased from 2560) nn.ReLU(), nn.Dropout(0.3), - nn.Linear(2560, 2048), # Still very wide + nn.Linear(4096, 3072), # Still very wide (increased from 2048) nn.ReLU(), nn.Dropout(0.3), - nn.Linear(2048, 1536), # Large hidden layer + nn.Linear(3072, 2048), # Large hidden layer (increased from 1536) nn.ReLU(), nn.Dropout(0.3), - nn.Linear(1536, 1024), # Final feature representation + nn.Linear(2048, 1024), # Final feature representation (increased from 1024, but keeping the same value to align with attention layers) nn.ReLU() ) - # Multiple attention mechanisms for different aspects (larger capacity) - self.price_attention = SelfAttention(1024) # Increased from 768 + # Multiple specialized attention mechanisms (larger capacity) + self.price_attention = SelfAttention(1024) # Keeping 1024 self.volume_attention = SelfAttention(1024) self.trend_attention = SelfAttention(1024) self.volatility_attention = SelfAttention(1024) @@ -209,108 +209,108 @@ class EnhancedCNN(nn.Module): # Ultra massive attention fusion layer self.attention_fusion = nn.Sequential( - nn.Linear(1024 * 6, 2048), # Combine all 6 attention outputs + nn.Linear(1024 * 6, 4096), # Combine all 6 attention outputs (increased from 2048) nn.ReLU(), nn.Dropout(0.3), - nn.Linear(2048, 1536), + nn.Linear(4096, 3072), # Increased from 1536 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(1536, 1024) + nn.Linear(3072, 1024) # Keeping 1024 ) # ULTRA MASSIVE dueling architecture with much deeper networks self.advantage_stream = nn.Sequential( - nn.Linear(1024, 768), + nn.Linear(1024, 1536), # Increased from 768 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(768, 512), + nn.Linear(1536, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, self.n_actions) + nn.Linear(256, self.n_actions) ) self.value_stream = nn.Sequential( - nn.Linear(1024, 768), + nn.Linear(1024, 1536), # Increased from 768 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(768, 512), + nn.Linear(1536, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, 1) + nn.Linear(256, 1) ) # ULTRA MASSIVE extrema detection head with deeper ensemble predictions self.extrema_head = nn.Sequential( - nn.Linear(1024, 768), + nn.Linear(1024, 1536), # Increased from 768 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(768, 512), + nn.Linear(1536, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, 3) # 0=bottom, 1=top, 2=neither + nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither ) # ULTRA MASSIVE multi-timeframe price prediction heads self.price_pred_immediate = nn.Sequential( - nn.Linear(1024, 512), + nn.Linear(1024, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, 3) # Up, Down, Sideways + nn.Linear(256, 3) # Up, Down, Sideways ) self.price_pred_midterm = nn.Sequential( - nn.Linear(1024, 512), + nn.Linear(1024, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, 3) # Up, Down, Sideways + nn.Linear(256, 3) # Up, Down, Sideways ) self.price_pred_longterm = nn.Sequential( - nn.Linear(1024, 512), + nn.Linear(1024, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 512), # Increased from 256 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(256, 128), + nn.Linear(512, 256), # Increased from 128 nn.ReLU(), - nn.Linear(128, 3) # Up, Down, Sideways + nn.Linear(256, 3) # Up, Down, Sideways ) # ULTRA MASSIVE value prediction with ensemble approaches self.price_pred_value = nn.Sequential( - nn.Linear(1024, 768), + nn.Linear(1024, 1536), # Increased from 768 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(768, 512), + nn.Linear(1536, 1024), # Increased from 512 nn.ReLU(), nn.Dropout(0.3), - nn.Linear(512, 256), + nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 128), @@ -391,7 +391,7 @@ class EnhancedCNN(nn.Module): # Handle 4D input [batch, timeframes, window, features] or 3D input [batch, timeframes, features] if len(x.shape) == 4: # Flatten window and features: [batch, timeframes, window*features] - x = x.view(batch_size, x.size(1), -1) + x = x.reshape(batch_size, x.size(1), -1) if self.conv_layers is not None: # Now x is 3D: [batch, timeframes, features] @@ -405,10 +405,10 @@ class EnhancedCNN(nn.Module): # Apply ultra massive convolutions x_conv = self.conv_layers(x_reshaped) # Flatten: [batch, channels, 1] -> [batch, channels] - x_flat = x_conv.view(batch_size, -1) + x_flat = x_conv.reshape(batch_size, -1) else: # If no conv layers, just flatten - x_flat = x.view(batch_size, -1) + x_flat = x.reshape(batch_size, -1) else: # For 2D input [batch, features] x_flat = x @@ -512,30 +512,30 @@ class EnhancedCNN(nn.Module): # Log advanced predictions for better decision making if hasattr(self, '_log_predictions') and self._log_predictions: # Log volatility prediction - volatility = torch.softmax(advanced_predictions['volatility'], dim=1) - volatility_class = torch.argmax(volatility, dim=1).item() + volatility = torch.softmax(advanced_predictions['volatility'], dim=1).squeeze(0) + volatility_class = int(torch.argmax(volatility).item()) volatility_labels = ['Very Low', 'Low', 'Medium', 'High', 'Very High'] # Log support/resistance prediction - sr = torch.softmax(advanced_predictions['support_resistance'], dim=1) - sr_class = torch.argmax(sr, dim=1).item() + sr = torch.softmax(advanced_predictions['support_resistance'], dim=1).squeeze(0) + sr_class = int(torch.argmax(sr).item()) sr_labels = ['Strong Support', 'Weak Support', 'Neutral', 'Weak Resistance', 'Strong Resistance', 'Breakout'] # Log market regime prediction - regime = torch.softmax(advanced_predictions['market_regime'], dim=1) - regime_class = torch.argmax(regime, dim=1).item() + regime = torch.softmax(advanced_predictions['market_regime'], dim=1).squeeze(0) + regime_class = int(torch.argmax(regime).item()) regime_labels = ['Bull Trend', 'Bear Trend', 'Sideways', 'Volatile Up', 'Volatile Down', 'Accumulation', 'Distribution'] # Log risk assessment - risk = torch.softmax(advanced_predictions['risk_assessment'], dim=1) - risk_class = torch.argmax(risk, dim=1).item() + risk = torch.softmax(advanced_predictions['risk_assessment'], dim=1).squeeze(0) + risk_class = int(torch.argmax(risk).item()) risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk'] logger.info(f"ULTRA MASSIVE Model Predictions:") - logger.info(f" Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})") - logger.info(f" Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})") - logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})") - logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[0, risk_class]:.3f})") + logger.info(f" Volatility: {volatility_labels[volatility_class]} ({volatility[volatility_class]:.3f})") + logger.info(f" Support/Resistance: {sr_labels[sr_class]} ({sr[sr_class]:.3f})") + logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})") + logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})") return action diff --git a/config.yaml b/config.yaml index eeaac27..d3de456 100644 --- a/config.yaml +++ b/config.yaml @@ -200,7 +200,7 @@ enhanced_training: min_training_samples: 100 # Minimum samples before training starts adaptation_threshold: 0.1 # Performance threshold for adaptation forward_looking_predictions: true # Enable forward-looking prediction validation - + # Real-time RL COB Trader Configuration realtime_rl: # Model parameters for 400M parameter network (faster startup) diff --git a/core/cob_integration.py b/core/cob_integration.py index df51db7..f8db1a7 100644 --- a/core/cob_integration.py +++ b/core/cob_integration.py @@ -34,7 +34,7 @@ class COBIntegration: Integration layer for Multi-Exchange COB data with gogo2 trading system """ - def __init__(self, data_provider: DataProvider = None, symbols: List[str] = None): + def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None): """ Initialize COB Integration @@ -45,15 +45,8 @@ class COBIntegration: self.data_provider = data_provider self.symbols = symbols or ['BTC/USDT', 'ETH/USDT'] - # Initialize COB provider - self.cob_provider = MultiExchangeCOBProvider( - symbols=self.symbols, - bucket_size_bps=1.0 # 1 basis point granularity - ) - - # Register callbacks - self.cob_provider.subscribe_to_cob_updates(self._on_cob_update) - self.cob_provider.subscribe_to_bucket_updates(self._on_bucket_update) + # Initialize COB provider to None, will be set in start() + self.cob_provider = None # CNN/DQN integration self.cnn_callbacks: List[Callable] = [] @@ -75,13 +68,23 @@ class COBIntegration: self.liquidity_alerts[symbol] = [] self.arbitrage_opportunities[symbol] = [] - logger.info("COB Integration initialized") + logger.info("COB Integration initialized (provider will be started in async)") logger.info(f"Symbols: {self.symbols}") async def start(self): """Start COB integration""" logger.info("Starting COB Integration") + # Initialize COB provider here, within the async context + self.cob_provider = MultiExchangeCOBProvider( + symbols=self.symbols, + bucket_size_bps=1.0 # 1 basis point granularity + ) + + # Register callbacks + self.cob_provider.subscribe_to_cob_updates(self._on_cob_update) + self.cob_provider.subscribe_to_bucket_updates(self._on_bucket_update) + # Start COB provider await self.cob_provider.start_streaming() @@ -94,7 +97,8 @@ class COBIntegration: async def stop(self): """Stop COB integration""" logger.info("Stopping COB Integration") - await self.cob_provider.stop_streaming() + if self.cob_provider: + await self.cob_provider.stop_streaming() logger.info("COB Integration stopped") def add_cnn_callback(self, callback: Callable[[str, Dict], None]): @@ -293,7 +297,9 @@ class COBIntegration: """Generate formatted data for dashboard visualization""" try: # Get fixed bucket size for the symbol - bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0) + bucket_size = 1.0 # Default bucket size + if self.cob_provider: + bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0) # Calculate price range for buckets mid_price = cob_snapshot.volume_weighted_mid @@ -338,15 +344,16 @@ class COBIntegration: # Get actual Session Volume Profile (SVP) from trade data svp_data = [] - try: - svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size) - if svp_result and 'data' in svp_result: - svp_data = svp_result['data'] - logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels") - else: - logger.warning(f"No SVP data available for {symbol}") - except Exception as e: - logger.error(f"Error getting SVP data for {symbol}: {e}") + if self.cob_provider: + try: + svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size) + if svp_result and 'data' in svp_result: + svp_data = svp_result['data'] + logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels") + else: + logger.warning(f"No SVP data available for {symbol}") + except Exception as e: + logger.error(f"Error getting SVP data for {symbol}: {e}") # Generate market stats stats = { @@ -381,19 +388,21 @@ class COBIntegration: stats['svp_price_levels'] = 0 stats['session_start'] = '' - # Add real-time statistics for NN models - try: - realtime_stats = self.cob_provider.get_realtime_stats(symbol) - if realtime_stats: - stats['realtime_1s'] = realtime_stats.get('1s_stats', {}) - stats['realtime_5s'] = realtime_stats.get('5s_stats', {}) - else: + # Get additional real-time stats + realtime_stats = {} + if self.cob_provider: + try: + realtime_stats = self.cob_provider.get_realtime_stats(symbol) + if realtime_stats: + stats['realtime_1s'] = realtime_stats.get('1s_stats', {}) + stats['realtime_5s'] = realtime_stats.get('5s_stats', {}) + else: + stats['realtime_1s'] = {} + stats['realtime_5s'] = {} + except Exception as e: + logger.error(f"Error getting real-time stats for {symbol}: {e}") stats['realtime_1s'] = {} stats['realtime_5s'] = {} - except Exception as e: - logger.error(f"Error getting real-time stats for {symbol}: {e}") - stats['realtime_1s'] = {} - stats['realtime_5s'] = {} return { 'type': 'cob_update', @@ -463,9 +472,10 @@ class COBIntegration: while True: try: for symbol in self.symbols: - cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol) - if cob_snapshot: - await self._analyze_cob_patterns(symbol, cob_snapshot) + if self.cob_provider: + cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol) + if cob_snapshot: + await self._analyze_cob_patterns(symbol, cob_snapshot) await asyncio.sleep(1) @@ -540,18 +550,26 @@ class COBIntegration: def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]: """Get latest COB snapshot for a symbol""" + if not self.cob_provider: + return None return self.cob_provider.get_consolidated_orderbook(symbol) def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]: """Get detailed market depth analysis""" + if not self.cob_provider: + return None return self.cob_provider.get_market_depth_analysis(symbol) def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]: """Get liquidity breakdown by exchange""" + if not self.cob_provider: + return None return self.cob_provider.get_exchange_breakdown(symbol) def get_price_buckets(self, symbol: str) -> Optional[Dict]: """Get fine-grain price buckets""" + if not self.cob_provider: + return None return self.cob_provider.get_price_buckets(symbol) def get_recent_signals(self, symbol: str, count: int = 20) -> List[Dict]: @@ -560,6 +578,16 @@ class COBIntegration: def get_statistics(self) -> Dict[str, Any]: """Get COB integration statistics""" + if not self.cob_provider: + return { + 'cnn_callbacks': len(self.cnn_callbacks), + 'dqn_callbacks': len(self.dqn_callbacks), + 'dashboard_callbacks': len(self.dashboard_callbacks), + 'cached_features': list(self.cob_feature_cache.keys()), + 'total_signals': {symbol: len(signals) for symbol, signals in self.cob_signals.items()}, + 'provider_status': 'Not initialized' + } + provider_stats = self.cob_provider.get_statistics() return { @@ -574,6 +602,11 @@ class COBIntegration: def get_realtime_stats_for_nn(self, symbol: str) -> Dict: """Get real-time statistics formatted for NN models""" try: + # Check if COB provider is initialized + if not self.cob_provider: + logger.debug(f"COB provider not initialized yet for {symbol}") + return {} + realtime_stats = self.cob_provider.get_realtime_stats(symbol) if not realtime_stats: return {} @@ -608,4 +641,66 @@ class COBIntegration: except Exception as e: logger.error(f"Error getting NN stats for {symbol}: {e}") - return {} \ No newline at end of file + return {} + + def get_realtime_stats(self): + # Added null check to ensure the COB provider is initialized + if self.cob_provider is None: + logger.warning("COB provider is uninitialized; attempting initialization.") + self.initialize_provider() + if self.cob_provider is None: + logger.error("COB provider failed to initialize; returning default empty snapshot.") + return COBSnapshot( + symbol="", + timestamp=0, + exchanges_active=0, + total_bid_liquidity=0, + total_ask_liquidity=0, + price_buckets=[], + volume_weighted_mid=0, + spread_bps=0, + liquidity_imbalance=0, + consolidated_bids=[], + consolidated_asks=[] + ) + try: + snapshot = self.cob_provider.get_realtime_stats() + return snapshot + except Exception as e: + logger.error(f"Error retrieving COB snapshot: {e}") + return COBSnapshot( + symbol="", + timestamp=0, + exchanges_active=0, + total_bid_liquidity=0, + total_ask_liquidity=0, + price_buckets=[], + volume_weighted_mid=0, + spread_bps=0, + liquidity_imbalance=0, + consolidated_bids=[], + consolidated_asks=[] + ) + + def stop_streaming(self): + pass + + def _initialize_cob_integration(self): + """Initialize COB integration with high-frequency data handling""" + logger.info("Initializing COB integration...") + if not COB_INTEGRATION_AVAILABLE: + logger.warning("COB integration not available - skipping initialization") + return + + try: + if not hasattr(self.orchestrator, 'cob_integration') or self.orchestrator.cob_integration is None: + logger.info("Creating new COB integration instance") + self.orchestrator.cob_integration = COBIntegration(self.data_provider) + else: + logger.info("Using existing COB integration from orchestrator") + + # Start simple COB data collection for both symbols + self._start_simple_cob_collection() + logger.info("COB integration initialized successfully") + except Exception as e: + logger.error(f"Error initializing COB integration: {e}") \ No newline at end of file diff --git a/core/multi_exchange_cob_provider.py b/core/multi_exchange_cob_provider.py index 496c69c..96fbff7 100644 --- a/core/multi_exchange_cob_provider.py +++ b/core/multi_exchange_cob_provider.py @@ -33,7 +33,7 @@ except ImportError: import numpy as np import pandas as pd from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Any, Callable, Union +from typing import Dict, List, Optional, Tuple, Any, Callable, Union, Awaitable from collections import deque, defaultdict from dataclasses import dataclass, field from threading import Thread, Lock @@ -194,6 +194,11 @@ class MultiExchangeCOBProvider: # Thread safety self.data_lock = asyncio.Lock() + # Initialize aiohttp session and connector to None, will be set up in start_streaming + self.session: Optional[aiohttp.ClientSession] = None + self.connector: Optional[aiohttp.TCPConnector] = None + self.rest_session: Optional[aiohttp.ClientSession] = None # Added for explicit None initialization + # Create REST API session # Fix for Windows aiodns issue - use ThreadedResolver instead connector = aiohttp.TCPConnector( @@ -286,64 +291,62 @@ class MultiExchangeCOBProvider: return configs async def start_streaming(self): - """Start streaming from all configured exchanges""" - if self.is_streaming: - logger.warning("COB streaming already active") - return - - logger.info("Starting Multi-Exchange COB streaming") + """Start real-time order book streaming from all configured exchanges""" + logger.info(f"Starting COB streaming for symbols: {self.symbols}") self.is_streaming = True - # Start streaming tasks for each exchange and symbol + # Setup aiohttp session here, within the async context + await self._setup_http_session() + + # Start WebSocket connections for each active exchange and symbol tasks = [] - - for exchange_name in self.active_exchanges: - for symbol in self.symbols: - # WebSocket task for real-time top 20 levels - task = asyncio.create_task( - self._stream_exchange_orderbook(exchange_name, symbol) - ) - tasks.append(task) - - # REST API task for deep order book snapshots - deep_task = asyncio.create_task( - self._stream_deep_orderbook(exchange_name, symbol) - ) - tasks.append(deep_task) - - # Trade stream task for SVP - if exchange_name == 'binance': - trade_task = asyncio.create_task( - self._stream_binance_trades(symbol) - ) - tasks.append(trade_task) - - # Start consolidation and analysis tasks - tasks.extend([ - asyncio.create_task(self._continuous_consolidation()), - asyncio.create_task(self._continuous_bucket_updates()) - ]) - - # Wait for all tasks - try: - await asyncio.gather(*tasks) - except Exception as e: - logger.error(f"Error in streaming tasks: {e}") - finally: - self.is_streaming = False + for symbol in self.symbols: + for exchange_name, config in self.exchange_configs.items(): + if config.enabled and exchange_name in self.active_exchanges: + # Start WebSocket stream + tasks.append(self._stream_exchange_orderbook(exchange_name, symbol)) + + # Start deep order book (REST API) stream + tasks.append(self._stream_deep_orderbook(exchange_name, symbol)) + + # Start trade stream (for SVP) + if exchange_name == 'binance': # Only Binance for now + tasks.append(self._stream_binance_trades(symbol)) + + # Start continuous consolidation and bucket updates + tasks.append(self._continuous_consolidation()) + tasks.append(self._continuous_bucket_updates()) + + logger.info(f"Starting {len(tasks)} COB streaming tasks") + await asyncio.gather(*tasks) + + async def _setup_http_session(self): + """Setup aiohttp session and connector""" + self.connector = aiohttp.TCPConnector( + resolver=aiohttp.ThreadedResolver() # This is now created inside async function + ) + self.session = aiohttp.ClientSession(connector=self.connector) + self.rest_session = aiohttp.ClientSession(connector=self.connector) # Moved here from __init__ + logger.info("aiohttp session and connector setup completed") async def stop_streaming(self): - """Stop streaming from all exchanges""" - logger.info("Stopping Multi-Exchange COB streaming") + """Stop real-time order book streaming and close sessions""" + logger.info("Stopping COB Integration") self.is_streaming = False - - # Close REST API session - if self.rest_session: + + if self.session and not self.session.closed: + await self.session.close() + logger.info("aiohttp session closed") + + if self.rest_session and not self.rest_session.closed: await self.rest_session.close() - self.rest_session = None - - # Wait a bit for tasks to stop gracefully - await asyncio.sleep(1) + logger.info("aiohttp REST session closed") + + if self.connector and not self.connector.closed: + await self.connector.close() + logger.info("aiohttp connector closed") + + logger.info("COB Integration stopped") async def _stream_deep_orderbook(self, exchange_name: str, symbol: str): """Fetch deep order book data via REST API periodically""" @@ -1086,12 +1089,12 @@ class MultiExchangeCOBProvider: # Public interface methods - def subscribe_to_cob_updates(self, callback: Callable[[str, COBSnapshot], None]): + def subscribe_to_cob_updates(self, callback: Callable[[str, COBSnapshot], Awaitable[None]]): """Subscribe to consolidated order book updates""" self.cob_update_callbacks.append(callback) logger.info(f"Added COB update callback: {len(self.cob_update_callbacks)} total") - def subscribe_to_bucket_updates(self, callback: Callable[[str, Dict], None]): + def subscribe_to_bucket_updates(self, callback: Callable[[str, Dict], Awaitable[None]]): """Subscribe to price bucket updates""" self.bucket_update_callbacks.append(callback) logger.info(f"Added bucket update callback: {len(self.bucket_update_callbacks)} total") diff --git a/core/orchestrator.py b/core/orchestrator.py index 71b7566..0cb2eb5 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -386,7 +386,7 @@ class TradingOrchestrator: # Import COB integration directly (same as working dashboard) from core.cob_integration import COBIntegration - # Initialize COB integration with our symbols + # Initialize COB integration with our symbols (but don't start it yet) self.cob_integration = COBIntegration(symbols=self.symbols) # Register callbacks to receive real-time COB data @@ -440,13 +440,21 @@ class TradingOrchestrator: except Exception as e: logger.error(f"Error initializing COB integration: {e}") - self.cob_integration = None + self.cob_integration = None # Ensure it's None if init fails logger.info("COB integration will be disabled - models will use basic price data") async def start_cob_integration(self): """Start COB integration with matrix data collection""" try: - if self.cob_integration: + if not self.cob_integration: + logger.info("COB integration not initialized yet, creating instance.") + from core.cob_integration import COBIntegration + self.cob_integration = COBIntegration(symbols=self.symbols) + # Re-register callbacks if COBIntegration was just created + self.cob_integration.add_cnn_callback(self._on_cob_cnn_features) + self.cob_integration.add_dqn_callback(self._on_cob_dqn_features) + self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data) + logger.info("Starting COB integration with 5-minute matrix collection...") # Start COB integration in background thread @@ -480,12 +488,11 @@ class TradingOrchestrator: self._start_cob_matrix_worker() logger.info("COB Integration started - 5-minute data matrix streaming active") - else: - logger.warning("COB integration is None - cannot start") except Exception as e: logger.error(f"Error starting COB integration: {e}") self.cob_integration = None + logger.info("COB integration will be disabled - models will use basic price data") def _start_cob_matrix_worker(self): """Start background worker for COB matrix updates""" @@ -760,7 +767,18 @@ class TradingOrchestrator: def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]: """Get latest COB snapshot for a symbol""" - return self.latest_cob_data.get(symbol) + try: + # First try to get from COB integration (live data) + if self.cob_integration: + snapshot = self.cob_integration.get_cob_snapshot(symbol) + if snapshot: + return snapshot + + # Fallback to cached data if COB integration not available + return self.latest_cob_data.get(symbol) + except Exception as e: + logger.warning(f"Error getting COB snapshot for {symbol}: {e}") + return None def get_cob_feature_matrix(self, symbol: str, sequence_length: int = 60) -> Optional[np.ndarray]: """ @@ -1325,12 +1343,25 @@ class TradingOrchestrator: if state is None: return None - # Get RL agent's action and confidence - use the actual underlying model + # Get RL agent's action, confidence, and q_values from the underlying model if hasattr(model.model, 'act_with_confidence'): - action_idx, confidence = model.model.act_with_confidence(state) + # Call act_with_confidence and handle different return formats + result = model.model.act_with_confidence(state) + + if len(result) == 3: + # EnhancedCNN format: (action, confidence, q_values) + action_idx, confidence, raw_q_values = result + elif len(result) == 2: + # DQN format: (action, confidence) + action_idx, confidence = result + raw_q_values = None + else: + logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values") + return None elif hasattr(model.model, 'act'): action_idx = model.model.act(state, explore=False) confidence = 0.7 # Default confidence for basic act method + raw_q_values = None # No raw q_values from simple act else: logger.error(f"RL model {model.name} has no act method") return None @@ -1338,11 +1369,19 @@ class TradingOrchestrator: action_names = ['SELL', 'HOLD', 'BUY'] action = action_names[action_idx] + # Convert raw_q_values to list if they are a tensor + q_values_for_capture = None + if raw_q_values is not None and hasattr(raw_q_values, 'tolist'): + q_values_for_capture = raw_q_values.tolist() + elif raw_q_values is not None and isinstance(raw_q_values, list): + q_values_for_capture = raw_q_values + # Create prediction object prediction = Prediction( action=action, confidence=float(confidence), - probabilities={action: float(confidence), 'HOLD': 1.0 - float(confidence)}, + # Use actual q_values if available, otherwise default probabilities + probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))}, timeframe='mixed', # RL uses mixed timeframes timestamp=datetime.now(), model_name=model.name, @@ -1352,17 +1391,9 @@ class TradingOrchestrator: # Capture DQN prediction for dashboard visualization current_price = self._get_current_price(symbol) if current_price: - # Get Q-values if available - q_values = [0.33, 0.33, 0.34] # Default - if hasattr(model, 'get_q_values'): - try: - q_values = model.get_q_values(state) - if hasattr(q_values, 'tolist'): - q_values = q_values.tolist() - except: - pass - - self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values) + # Only pass q_values if they exist, otherwise pass empty list + q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else [] + self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass) return prediction @@ -2434,11 +2465,11 @@ class TradingOrchestrator: self.decision_fusion_network = DecisionFusionNet() logger.info("Decision fusion network initialized") - + except Exception as e: logger.warning(f"Decision fusion initialization failed: {e}") self.decision_fusion_enabled = False - + def _initialize_enhanced_training_system(self): """Initialize the enhanced real-time training system""" try: @@ -2599,7 +2630,7 @@ class TradingOrchestrator: if self.enhanced_training_system: self.enhanced_training_system.dashboard = dashboard logger.info("Dashboard reference set for enhanced training system") - + except Exception as e: logger.error(f"Error setting training dashboard: {e}") diff --git a/core/training_integration.py b/core/training_integration.py index babbea5..a6d7a3d 100644 --- a/core/training_integration.py +++ b/core/training_integration.py @@ -13,6 +13,9 @@ import logging from datetime import datetime from typing import Dict, List, Any, Optional import numpy as np +from utils.reward_calculator import RewardCalculator +import threading +import time logger = logging.getLogger(__name__) @@ -21,8 +24,16 @@ class TrainingIntegration: def __init__(self, orchestrator=None): self.orchestrator = orchestrator + self.reward_calculator = RewardCalculator() self.training_sessions = {} self.min_confidence_threshold = 0.15 # Lowered from 0.3 for more aggressive training + self.training_active = False + self.trainer_thread = None + self.stop_event = threading.Event() + self.training_lock = threading.Lock() + self.last_training_time = 0.0 if orchestrator is None else time.time() + self.training_interval = 300 # 5 minutes between training sessions + self.min_data_points = 100 # Minimum data points required to trigger training logger.info("TrainingIntegration initialized") @@ -347,46 +358,32 @@ class TrainingIntegration: return False def get_training_status(self) -> Dict[str, Any]: - """Get current training integration status""" + """Get current training status""" try: status = { - 'orchestrator_available': self.orchestrator is not None, - 'training_sessions': len(self.training_sessions), - 'last_update': datetime.now().isoformat() + 'active': self.training_active, + 'last_training_time': self.last_training_time, + 'training_sessions': self.training_sessions if self.training_sessions else {} } - - if self.orchestrator: - status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None - status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None - status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None - return status - except Exception as e: logger.error(f"Error getting training status: {e}") - return {'error': str(e)} + return {} def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str: """Start a new training session""" try: session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - - session_data = { - 'session_id': session_id, - 'session_name': session_name, - 'start_time': datetime.now().isoformat(), - 'config': config or {}, + self.training_sessions[session_id] = { + 'name': session_name, + 'start_time': datetime.now(), + 'config': config if config else {}, 'trades_processed': 0, - 'successful_trainings': 0, - 'failed_trainings': 0 + 'training_attempts': 0, + 'successful_trainings': 0 } - - self.training_sessions[session_id] = session_data - logger.info(f"Started training session: {session_id}") - return session_id - except Exception as e: logger.error(f"Error starting training session: {e}") return "" diff --git a/debug/debug_callback_simple.py b/debug/debug_callback_simple.py deleted file mode 100644 index 582fa26..0000000 --- a/debug/debug_callback_simple.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple callback debug script to see exact error -""" - -import requests -import json - -def test_simple_callback(): - """Test a simple callback to see the exact error""" - try: - # Test the simplest possible callback - callback_data = { - "output": "current-balance.children", - "inputs": [ - { - "id": "ultra-fast-interval", - "property": "n_intervals", - "value": 1 - } - ] - } - - print("Sending callback request...") - response = requests.post( - 'http://127.0.0.1:8051/_dash-update-component', - json=callback_data, - timeout=15, - headers={'Content-Type': 'application/json'} - ) - - print(f"Status Code: {response.status_code}") - print(f"Response Headers: {dict(response.headers)}") - print(f"Response Text (first 1000 chars):") - print(response.text[:1000]) - print("=" * 50) - - if response.status_code == 500: - # Try to extract error from HTML - if "Traceback" in response.text: - lines = response.text.split('\n') - for i, line in enumerate(lines): - if "Traceback" in line: - # Print next 20 lines for error details - for j in range(i, min(i+20, len(lines))): - print(lines[j]) - break - - except Exception as e: - print(f"Request failed: {e}") - -if __name__ == "__main__": - test_simple_callback() \ No newline at end of file diff --git a/debug/debug_dashboard.py b/debug/debug_dashboard.py deleted file mode 100644 index f269a54..0000000 --- a/debug/debug_dashboard.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug Dashboard - Minimal version to test callback functionality -""" - -import logging -import sys -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -import dash -from dash import dcc, html, Input, Output -import plotly.graph_objects as go - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def create_debug_dashboard(): - """Create minimal debug dashboard""" - - app = dash.Dash(__name__) - - app.layout = html.Div([ - html.H1("๐Ÿ”ง Debug Dashboard - Callback Test", className="text-center"), - html.Div([ - html.H3(id="debug-time", className="text-center"), - html.H4(id="debug-counter", className="text-center"), - html.P(id="debug-status", className="text-center"), - dcc.Graph(id="debug-chart") - ]), - dcc.Interval( - id='debug-interval', - interval=2000, # 2 seconds - n_intervals=0 - ) - ]) - - @app.callback( - [ - Output('debug-time', 'children'), - Output('debug-counter', 'children'), - Output('debug-status', 'children'), - Output('debug-chart', 'figure') - ], - [Input('debug-interval', 'n_intervals')] - ) - def update_debug_dashboard(n_intervals): - """Debug callback function""" - try: - logger.info(f"๐Ÿ”ง DEBUG: Callback triggered, interval: {n_intervals}") - - current_time = datetime.now().strftime("%H:%M:%S") - counter = f"Updates: {n_intervals}" - status = f"Callback working! Last update: {current_time}" - - # Create simple test chart - fig = go.Figure() - fig.add_trace(go.Scatter( - x=list(range(max(0, n_intervals-10), n_intervals + 1)), - y=[i**2 for i in range(max(0, n_intervals-10), n_intervals + 1)], - mode='lines+markers', - name='Debug Data', - line=dict(color='#00ff88') - )) - fig.update_layout( - title=f"Debug Chart - Update #{n_intervals}", - template="plotly_dark", - paper_bgcolor='#1e1e1e', - plot_bgcolor='#1e1e1e' - ) - - logger.info(f"โœ… DEBUG: Returning data - time={current_time}, counter={counter}") - - return current_time, counter, status, fig - - except Exception as e: - logger.error(f"โŒ DEBUG: Error in callback: {e}") - import traceback - logger.error(f"Traceback: {traceback.format_exc()}") - return "Error", "Error", "Callback failed", {} - - return app - -def main(): - """Run the debug dashboard""" - logger.info("๐Ÿ”ง Starting debug dashboard...") - - try: - app = create_debug_dashboard() - logger.info("โœ… Debug dashboard created") - - logger.info("๐Ÿš€ Starting debug dashboard on http://127.0.0.1:8053") - logger.info("This will test if Dash callbacks work at all") - logger.info("Press Ctrl+C to stop") - - app.run(host='127.0.0.1', port=8053, debug=True) - - except KeyboardInterrupt: - logger.info("Debug dashboard stopped by user") - except Exception as e: - logger.error(f"โŒ Error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/debug/debug_dashboard_500.py b/debug/debug_dashboard_500.py deleted file mode 100644 index 0ad26a7..0000000 --- a/debug/debug_dashboard_500.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug Dashboard - Enhanced error logging to identify 500 errors -""" - -import logging -import sys -import traceback -from pathlib import Path -from datetime import datetime -import pandas as pd - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -import dash -from dash import dcc, html, Input, Output -import plotly.graph_objects as go - -from core.config import setup_logging -from core.data_provider import DataProvider - -# Setup logging without emojis -logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(sys.stdout), - logging.FileHandler('debug_dashboard.log') - ] -) -logger = logging.getLogger(__name__) - -class DebugDashboard: - """Debug dashboard with enhanced error logging""" - - def __init__(self): - logger.info("Initializing debug dashboard...") - - try: - self.data_provider = DataProvider() - logger.info("Data provider initialized successfully") - except Exception as e: - logger.error(f"Error initializing data provider: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - raise - - # Initialize app - self.app = dash.Dash(__name__) - logger.info("Dash app created") - - # Setup layout and callbacks - try: - self._setup_layout() - logger.info("Layout setup completed") - except Exception as e: - logger.error(f"Error setting up layout: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - raise - - try: - self._setup_callbacks() - logger.info("Callbacks setup completed") - except Exception as e: - logger.error(f"Error setting up callbacks: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - raise - - logger.info("Debug dashboard initialized successfully") - - def _setup_layout(self): - """Setup minimal layout for debugging""" - logger.info("Setting up layout...") - - self.app.layout = html.Div([ - html.H1("Debug Dashboard - 500 Error Investigation", className="text-center"), - - # Simple metrics - html.Div([ - html.Div([ - html.H3(id="current-time", children="Loading..."), - html.P("Current Time") - ], className="col-md-3"), - - html.Div([ - html.H3(id="update-counter", children="0"), - html.P("Update Count") - ], className="col-md-3"), - - html.Div([ - html.H3(id="status", children="Starting..."), - html.P("Status") - ], className="col-md-3"), - - html.Div([ - html.H3(id="error-count", children="0"), - html.P("Error Count") - ], className="col-md-3") - ], className="row mb-4"), - - # Error log - html.Div([ - html.H4("Error Log"), - html.Div(id="error-log", children="No errors yet...") - ], className="mb-4"), - - # Simple chart - html.Div([ - dcc.Graph(id="debug-chart", style={"height": "300px"}) - ]), - - # Interval component - dcc.Interval( - id='debug-interval', - interval=2000, # 2 seconds for easier debugging - n_intervals=0 - ) - ], className="container-fluid") - - logger.info("Layout setup completed") - - def _setup_callbacks(self): - """Setup callbacks with extensive error handling""" - logger.info("Setting up callbacks...") - - # Store reference to self - dashboard_instance = self - error_count = 0 - error_log = [] - - @self.app.callback( - [ - Output('current-time', 'children'), - Output('update-counter', 'children'), - Output('status', 'children'), - Output('error-count', 'children'), - Output('error-log', 'children'), - Output('debug-chart', 'figure') - ], - [Input('debug-interval', 'n_intervals')] - ) - def update_debug_dashboard(n_intervals): - """Debug callback with extensive error handling""" - nonlocal error_count, error_log - - logger.info(f"=== CALLBACK START - Interval {n_intervals} ===") - - try: - # Current time - current_time = datetime.now().strftime("%H:%M:%S") - logger.info(f"Current time: {current_time}") - - # Update counter - counter = f"Updates: {n_intervals}" - logger.info(f"Counter: {counter}") - - # Status - status = "Running OK" if n_intervals > 0 else "Starting" - logger.info(f"Status: {status}") - - # Error count - error_count_str = f"Errors: {error_count}" - logger.info(f"Error count: {error_count_str}") - - # Error log display - if error_log: - error_display = html.Div([ - html.P(f"Error {i+1}: {error}", className="text-danger") - for i, error in enumerate(error_log[-5:]) # Show last 5 errors - ]) - else: - error_display = "No errors yet..." - - # Create chart - logger.info("Creating chart...") - try: - chart = dashboard_instance._create_debug_chart(n_intervals) - logger.info("Chart created successfully") - except Exception as chart_error: - logger.error(f"Error creating chart: {chart_error}") - logger.error(f"Chart error traceback: {traceback.format_exc()}") - error_count += 1 - error_log.append(f"Chart error: {str(chart_error)}") - chart = dashboard_instance._create_error_chart(str(chart_error)) - - logger.info("=== CALLBACK SUCCESS ===") - - return current_time, counter, status, error_count_str, error_display, chart - - except Exception as e: - error_count += 1 - error_msg = f"Callback error: {str(e)}" - error_log.append(error_msg) - - logger.error(f"=== CALLBACK ERROR ===") - logger.error(f"Error: {e}") - logger.error(f"Error type: {type(e)}") - logger.error(f"Traceback: {traceback.format_exc()}") - - # Return safe fallback values - error_chart = dashboard_instance._create_error_chart(str(e)) - error_display = html.Div([ - html.P(f"CALLBACK ERROR: {str(e)}", className="text-danger"), - html.P(f"Error count: {error_count}", className="text-warning") - ]) - - return "ERROR", f"Errors: {error_count}", "FAILED", f"Errors: {error_count}", error_display, error_chart - - logger.info("Callbacks setup completed") - - def _create_debug_chart(self, n_intervals): - """Create a simple debug chart""" - logger.info(f"Creating debug chart for interval {n_intervals}") - - try: - # Try to get real data every 5 intervals - if n_intervals % 5 == 0: - logger.info("Attempting to fetch real data...") - try: - df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=20) - if df is not None and not df.empty: - logger.info(f"Fetched {len(df)} real candles") - self.chart_data = df - else: - logger.warning("No real data returned") - except Exception as data_error: - logger.error(f"Error fetching real data: {data_error}") - logger.error(f"Data fetch traceback: {traceback.format_exc()}") - - # Create chart - fig = go.Figure() - - if hasattr(self, 'chart_data') and not self.chart_data.empty: - logger.info("Using real data for chart") - fig.add_trace(go.Scatter( - x=self.chart_data['timestamp'], - y=self.chart_data['close'], - mode='lines', - name='ETH/USDT Real', - line=dict(color='#00ff88') - )) - title = f"ETH/USDT Real Data - Update #{n_intervals}" - else: - logger.info("Using mock data for chart") - # Simple mock data - x_data = list(range(max(0, n_intervals-10), n_intervals + 1)) - y_data = [3500 + 50 * (i % 5) for i in x_data] - - fig.add_trace(go.Scatter( - x=x_data, - y=y_data, - mode='lines', - name='Mock Data', - line=dict(color='#ff8800') - )) - title = f"Mock Data - Update #{n_intervals}" - - fig.update_layout( - title=title, - template="plotly_dark", - paper_bgcolor='#1e1e1e', - plot_bgcolor='#1e1e1e', - showlegend=False, - height=300 - ) - - logger.info("Chart created successfully") - return fig - - except Exception as e: - logger.error(f"Error in _create_debug_chart: {e}") - logger.error(f"Chart creation traceback: {traceback.format_exc()}") - raise - - def _create_error_chart(self, error_msg): - """Create error chart""" - logger.info(f"Creating error chart: {error_msg}") - - fig = go.Figure() - fig.add_annotation( - text=f"Chart Error: {error_msg}", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=14, color="#ff4444") - ) - fig.update_layout( - template="plotly_dark", - paper_bgcolor='#1e1e1e', - plot_bgcolor='#1e1e1e', - height=300 - ) - return fig - - def run(self, host='127.0.0.1', port=8053, debug=True): - """Run the debug dashboard""" - logger.info(f"Starting debug dashboard at http://{host}:{port}") - logger.info("This dashboard has enhanced error logging to identify 500 errors") - - try: - self.app.run(host=host, port=port, debug=debug) - except Exception as e: - logger.error(f"Error running dashboard: {e}") - logger.error(f"Run error traceback: {traceback.format_exc()}") - raise - -def main(): - """Main function""" - logger.info("Starting debug dashboard main...") - - try: - dashboard = DebugDashboard() - dashboard.run() - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"Fatal error: {e}") - logger.error(f"Fatal traceback: {traceback.format_exc()}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/debug/debug_dashboard_issue.py b/debug/debug_dashboard_issue.py deleted file mode 100644 index aad2700..0000000 --- a/debug/debug_dashboard_issue.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug Dashboard Data Flow - -Check if the dashboard is receiving data and updating properly. -""" - -import sys -import logging -import time -import requests -import json -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import get_config, setup_logging -from core.data_provider import DataProvider - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -def test_data_provider(): - """Test if data provider is working""" - logger.info("=== TESTING DATA PROVIDER ===") - - try: - # Test data provider - data_provider = DataProvider() - - # Test current price - logger.info("Testing current price retrieval...") - current_price = data_provider.get_current_price('ETH/USDT') - logger.info(f"Current ETH/USDT price: ${current_price}") - - # Test historical data - logger.info("Testing historical data retrieval...") - df = data_provider.get_historical_data('ETH/USDT', '1m', limit=5, refresh=True) - if df is not None and not df.empty: - logger.info(f"Historical data: {len(df)} rows") - logger.info(f"Latest price: ${df['close'].iloc[-1]:.2f}") - logger.info(f"Latest timestamp: {df.index[-1]}") - else: - logger.error("No historical data available!") - - return True - - except Exception as e: - logger.error(f"Data provider test failed: {e}") - return False - -def test_dashboard_api(): - """Test if dashboard API is responding""" - logger.info("=== TESTING DASHBOARD API ===") - - try: - # Test main dashboard page - response = requests.get("http://127.0.0.1:8050", timeout=5) - logger.info(f"Dashboard main page status: {response.status_code}") - - if response.status_code == 200: - logger.info("Dashboard is responding") - - # Check if there are any JavaScript errors in the page - content = response.text - if 'error' in content.lower(): - logger.warning("Possible errors found in dashboard HTML") - - return True - else: - logger.error(f"Dashboard returned status {response.status_code}") - return False - - except Exception as e: - logger.error(f"Dashboard API test failed: {e}") - return False - -def test_dashboard_callbacks(): - """Test dashboard callback updates""" - logger.info("=== TESTING DASHBOARD CALLBACKS ===") - - try: - # Test the callback endpoint (this would need to be exposed) - # For now, just check if the dashboard is serving content - - # Wait a bit and check again - time.sleep(2) - - response = requests.get("http://127.0.0.1:8050", timeout=5) - if response.status_code == 200: - logger.info("Dashboard callbacks appear to be working") - return True - else: - logger.error("Dashboard callbacks may be stuck") - return False - - except Exception as e: - logger.error(f"Dashboard callback test failed: {e}") - return False - -def main(): - """Run all diagnostic tests""" - logger.info("DASHBOARD DIAGNOSTIC TOOL") - logger.info("=" * 50) - - results = { - 'data_provider': test_data_provider(), - 'dashboard_api': test_dashboard_api(), - 'dashboard_callbacks': test_dashboard_callbacks() - } - - logger.info("=" * 50) - logger.info("DIAGNOSTIC RESULTS:") - - for test_name, result in results.items(): - status = "PASS" if result else "FAIL" - logger.info(f" {test_name}: {status}") - - if all(results.values()): - logger.info("All tests passed - issue may be browser-side") - logger.info("Try refreshing the dashboard at http://127.0.0.1:8050") - else: - logger.error("Issues detected - check logs above") - logger.info("Recommendations:") - - if not results['data_provider']: - logger.info(" - Check internet connection") - logger.info(" - Verify Binance API is accessible") - - if not results['dashboard_api']: - logger.info(" - Restart the dashboard") - logger.info(" - Check if port 8050 is blocked") - - if not results['dashboard_callbacks']: - logger.info(" - Dashboard may be frozen") - logger.info(" - Consider restarting") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/debug/debug_mexc_auth.py b/debug/debug_mexc_auth.py deleted file mode 100644 index 6e4b34a..0000000 --- a/debug/debug_mexc_auth.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug script for MEXC API authentication -""" - -import os -import hmac -import hashlib -import time -import requests -from urllib.parse import urlencode -from dotenv import load_dotenv - -# Load environment variables -load_dotenv() - -def debug_mexc_auth(): - """Debug MEXC API authentication step by step""" - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - print("="*60) - print("MEXC API AUTHENTICATION DEBUG") - print("="*60) - - print(f"API Key: {api_key}") - print(f"API Secret: {api_secret[:10]}...{api_secret[-10:]}") - print() - - # Test 1: Public API (no auth required) - print("1. Testing Public API (ping)...") - try: - response = requests.get("https://api.mexc.com/api/v3/ping") - print(f" Status: {response.status_code}") - print(f" Response: {response.json()}") - print(" โœ… Public API works") - except Exception as e: - print(f" โŒ Public API failed: {e}") - return - print() - - # Test 2: Get server time - print("2. Testing Server Time...") - try: - response = requests.get("https://api.mexc.com/api/v3/time") - server_time_data = response.json() - server_time = server_time_data['serverTime'] - print(f" Server Time: {server_time}") - print(" โœ… Server time retrieved") - except Exception as e: - print(f" โŒ Server time failed: {e}") - return - print() - - # Test 3: Manual signature generation and account request - print("3. Testing Authentication (manual signature)...") - - # Get server time for accurate timestamp - try: - server_response = requests.get("https://api.mexc.com/api/v3/time") - server_time = server_response.json()['serverTime'] - print(f" Using Server Time: {server_time}") - except: - server_time = int(time.time() * 1000) - print(f" Using Local Time: {server_time}") - - # Parameters for account endpoint - params = { - 'timestamp': server_time, - 'recvWindow': 10000 # Increased receive window - } - - print(f" Timestamp: {server_time}") - print(f" Params: {params}") - - # Generate signature manually - # According to MEXC documentation, parameters should be sorted - sorted_params = sorted(params.items()) - query_string = urlencode(sorted_params) - print(f" Query String: {query_string}") - - # MEXC documentation shows signature in lowercase - signature = hmac.new( - api_secret.encode('utf-8'), - query_string.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - - print(f" Generated Signature (hex): {signature}") - print(f" API Secret used: {api_secret[:5]}...{api_secret[-5:]}") - print(f" Query string length: {len(query_string)}") - print(f" Signature length: {len(signature)}") - - print(f" Generated Signature: {signature}") - - # Add signature to params - params['signature'] = signature - - # Make the request - headers = { - 'X-MEXC-APIKEY': api_key - } - - print(f" Headers: {headers}") - print(f" Final Params: {params}") - - try: - response = requests.get( - "https://api.mexc.com/api/v3/account", - params=params, - headers=headers - ) - - print(f" Status Code: {response.status_code}") - print(f" Response Headers: {dict(response.headers)}") - - if response.status_code == 200: - account_data = response.json() - print(f" โœ… Authentication successful!") - print(f" Account Type: {account_data.get('accountType', 'N/A')}") - print(f" Can Trade: {account_data.get('canTrade', 'N/A')}") - print(f" Can Withdraw: {account_data.get('canWithdraw', 'N/A')}") - print(f" Can Deposit: {account_data.get('canDeposit', 'N/A')}") - print(f" Number of balances: {len(account_data.get('balances', []))}") - - # Show USDT balance - for balance in account_data.get('balances', []): - if balance['asset'] == 'USDT': - print(f" ๐Ÿ’ฐ USDT Balance: {balance['free']} (locked: {balance['locked']})") - break - - else: - print(f" โŒ Authentication failed!") - print(f" Response: {response.text}") - - # Try to parse error - try: - error_data = response.json() - print(f" Error Code: {error_data.get('code', 'N/A')}") - print(f" Error Message: {error_data.get('msg', 'N/A')}") - except: - pass - - except Exception as e: - print(f" โŒ Request failed: {e}") - -if __name__ == "__main__": - debug_mexc_auth() \ No newline at end of file diff --git a/debug/debug_orchestrator_methods.py b/debug/debug_orchestrator_methods.py deleted file mode 100644 index 26a9394..0000000 --- a/debug/debug_orchestrator_methods.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug Orchestrator Methods - Test enhanced orchestrator method availability -""" - -import sys -from pathlib import Path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -def debug_orchestrator_methods(): - """Debug orchestrator method availability""" - print("=== DEBUGGING ORCHESTRATOR METHODS ===") - - try: - # Import the classes we need - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - print("โœ“ Imports successful") - - # Create basic data provider (no async) - dp = DataProvider() - print("โœ“ DataProvider created") - - # Create basic orchestrator first - basic_orch = TradingOrchestrator(dp) - print("โœ“ Basic TradingOrchestrator created") - - # Test basic orchestrator methods - basic_methods = ['calculate_enhanced_pivot_reward', 'build_comprehensive_rl_state'] - print("\nBasic TradingOrchestrator methods:") - for method in basic_methods: - available = hasattr(basic_orch, method) - print(f" {method}: {'โœ“' if available else 'โœ—'}") - - # Now test Enhanced orchestrator class methods (not instantiated) - print("\nEnhancedTradingOrchestrator class methods:") - for method in basic_methods: - available = hasattr(EnhancedTradingOrchestrator, method) - print(f" {method}: {'โœ“' if available else 'โœ—'}") - - # Check what methods are actually in the EnhancedTradingOrchestrator - print(f"\nEnhancedTradingOrchestrator all methods:") - all_methods = [m for m in dir(EnhancedTradingOrchestrator) if not m.startswith('_')] - enhanced_methods = [m for m in all_methods if 'enhanced' in m.lower() or 'comprehensive' in m.lower() or 'pivot' in m.lower()] - - print(f" Total methods: {len(all_methods)}") - print(f" Enhanced/comprehensive/pivot methods: {enhanced_methods}") - - # Test specific methods we're looking for - target_methods = [ - 'calculate_enhanced_pivot_reward', - 'build_comprehensive_rl_state', - '_get_symbol_correlation' - ] - - print(f"\nTarget methods in EnhancedTradingOrchestrator:") - for method in target_methods: - if hasattr(EnhancedTradingOrchestrator, method): - print(f" โœ“ {method}: Found") - else: - print(f" โœ— {method}: Missing") - # Check if it's a similar name - similar = [m for m in all_methods if method.replace('_', '').lower() in m.replace('_', '').lower()] - if similar: - print(f" Similar: {similar}") - - print("\n=== DEBUG COMPLETE ===") - - except Exception as e: - print(f"โœ— Debug failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - debug_orchestrator_methods() \ No newline at end of file diff --git a/debug/debug_simple_callback.py b/debug/debug_simple_callback.py deleted file mode 100644 index 37a079f..0000000 --- a/debug/debug_simple_callback.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug simple callback to see exact error -""" - -import requests -import json - -def debug_simple_callback(): - """Debug the simple callback""" - try: - callback_data = { - "output": "test-output.children", - "inputs": [ - { - "id": "test-interval", - "property": "n_intervals", - "value": 1 - } - ] - } - - print("Testing simple dashboard callback...") - response = requests.post( - 'http://127.0.0.1:8052/_dash-update-component', - json=callback_data, - timeout=15, - headers={'Content-Type': 'application/json'} - ) - - print(f"Status Code: {response.status_code}") - - if response.status_code == 500: - print("Error response:") - print(response.text) - else: - print("Success response:") - print(response.text[:500]) - - except Exception as e: - print(f"Request failed: {e}") - -if __name__ == "__main__": - debug_simple_callback() \ No newline at end of file diff --git a/debug/debug_trading_activity.py b/debug/debug_trading_activity.py deleted file mode 100644 index d2f2e26..0000000 --- a/debug/debug_trading_activity.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python3 -""" -Trading Activity Diagnostic Script -Debug why no trades are happening after 6 hours -""" - -import logging -import asyncio -from datetime import datetime, timedelta -import pandas as pd -import numpy as np - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -async def diagnose_trading_system(): - """Comprehensive diagnosis of trading system""" - logger.info("=== TRADING SYSTEM DIAGNOSTIC ===") - - try: - # Import core components - from core.config import get_config - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - - # Initialize components - config = get_config() - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - enhanced_rl_training=True - ) - - logger.info("โœ… Components initialized successfully") - - # 1. Check data availability - logger.info("\n=== DATA AVAILABILITY CHECK ===") - for symbol in ['ETH/USDT', 'BTC/USDT']: - for timeframe in ['1m', '5m', '1h']: - try: - data = data_provider.get_historical_data(symbol, timeframe, limit=10) - if data is not None and not data.empty: - logger.info(f"โœ… {symbol} {timeframe}: {len(data)} bars available") - logger.info(f" Last price: ${data['close'].iloc[-1]:.2f}") - else: - logger.error(f"โŒ {symbol} {timeframe}: NO DATA") - except Exception as e: - logger.error(f"โŒ {symbol} {timeframe}: ERROR - {e}") - - # 2. Check model status - logger.info("\n=== MODEL STATUS CHECK ===") - model_status = orchestrator.get_loaded_models_status() if hasattr(orchestrator, 'get_loaded_models_status') else {} - logger.info(f"Loaded models: {model_status}") - - # 3. Check confidence thresholds - logger.info("\n=== CONFIDENCE THRESHOLD CHECK ===") - logger.info(f"Entry threshold: {getattr(orchestrator, 'confidence_threshold_open', 'UNKNOWN')}") - logger.info(f"Exit threshold: {getattr(orchestrator, 'confidence_threshold_close', 'UNKNOWN')}") - logger.info(f"Config threshold: {config.orchestrator.get('confidence_threshold', 'UNKNOWN')}") - - # 4. Test decision making - logger.info("\n=== DECISION MAKING TEST ===") - try: - decisions = await orchestrator.make_coordinated_decisions() - logger.info(f"Generated {len(decisions)} decisions") - - for symbol, decision in decisions.items(): - if decision: - logger.info(f"โœ… {symbol}: {decision.action} " - f"(confidence: {decision.confidence:.3f}, " - f"price: ${decision.price:.2f})") - else: - logger.warning(f"โŒ {symbol}: No decision generated") - - except Exception as e: - logger.error(f"โŒ Decision making failed: {e}") - - # 5. Test cold start predictions - logger.info("\n=== COLD START PREDICTIONS TEST ===") - try: - await orchestrator.ensure_predictions_available() - logger.info("โœ… Cold start predictions system working") - except Exception as e: - logger.error(f"โŒ Cold start predictions failed: {e}") - - # 6. Check cross-asset signals - logger.info("\n=== CROSS-ASSET SIGNALS TEST ===") - try: - from core.unified_data_stream import UniversalDataStream - - # Create mock universal stream for testing - mock_stream = type('MockStream', (), {})() - mock_stream.get_latest_data = lambda symbol: {'price': 2500.0 if 'ETH' in symbol else 35000.0} - mock_stream.get_market_structure = lambda symbol: {'trend': 'NEUTRAL', 'strength': 0.5} - mock_stream.get_cob_data = lambda symbol: {'imbalance': 0.0, 'depth': 'BALANCED'} - - btc_analysis = await orchestrator._analyze_btc_price_action(mock_stream) - logger.info(f"BTC analysis result: {btc_analysis}") - - eth_decision = await orchestrator._make_eth_decision_from_btc_signals( - {'signal': 'NEUTRAL', 'strength': 0.5}, - {'signal': 'NEUTRAL', 'imbalance': 0.0} - ) - logger.info(f"ETH decision result: {eth_decision}") - - except Exception as e: - logger.error(f"โŒ Cross-asset signals failed: {e}") - - # 7. Simulate trade with lower thresholds - logger.info("\n=== SIMULATED TRADE TEST ===") - try: - # Create mock prediction with low confidence - from core.enhanced_orchestrator import EnhancedPrediction - - mock_prediction = EnhancedPrediction( - model_name="TEST", - timeframe="1m", - action="BUY", - confidence=0.30, # Lower confidence - overall_action="BUY", - overall_confidence=0.30, - timeframe_predictions=[], - reasoning="Test prediction" - ) - - # Test if this would generate a trade - current_price = 2500.0 - quantity = 0.01 - - logger.info(f"Mock prediction: {mock_prediction.action} " - f"(confidence: {mock_prediction.confidence:.3f})") - - if mock_prediction.confidence > 0.25: # Our new lower threshold - logger.info("โœ… Would generate trade with new threshold") - else: - logger.warning("โŒ Still below threshold") - - except Exception as e: - logger.error(f"โŒ Simulated trade test failed: {e}") - - # 8. Check RL reward functions - logger.info("\n=== RL REWARD FUNCTION TEST ===") - try: - # Test reward calculation - mock_trade = { - 'action': 'BUY', - 'confidence': 0.75, - 'price': 2500.0, - 'timestamp': datetime.now() - } - - mock_outcome = { - 'net_pnl': 25.0, # $25 profit - 'exit_price': 2525.0, - 'duration': timedelta(minutes=15) - } - - mock_market_data = { - 'volatility': 0.03, - 'order_flow_direction': 'bullish', - 'order_flow_strength': 0.8 - } - - if hasattr(orchestrator, 'calculate_enhanced_pivot_reward'): - reward = orchestrator.calculate_enhanced_pivot_reward( - mock_trade, mock_market_data, mock_outcome - ) - logger.info(f"โœ… RL reward for profitable trade: {reward:.3f}") - else: - logger.warning("โŒ Enhanced pivot reward function not available") - - except Exception as e: - logger.error(f"โŒ RL reward test failed: {e}") - - logger.info("\n=== DIAGNOSTIC COMPLETE ===") - logger.info("Check results above to identify trading bottlenecks") - - except Exception as e: - logger.error(f"Diagnostic failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - asyncio.run(diagnose_trading_system()) \ No newline at end of file diff --git a/enhanced_realtime_training.py b/enhanced_realtime_training.py index 66973c0..2bc70de 100644 --- a/enhanced_realtime_training.py +++ b/enhanced_realtime_training.py @@ -36,8 +36,8 @@ class EnhancedRealtimeTrainingSystem: self.training_config = { 'dqn_training_interval': 5, # Train DQN every 5 seconds 'cnn_training_interval': 10, # Train CNN every 10 seconds - 'batch_size': 64, # Larger batch size for stability - 'memory_size': 10000, # Larger memory for diversity + 'batch_size': 640, # Larger batch size for stability (increased 10x) + 'memory_size': 100000, # Larger memory for diversity (increased 10x) 'validation_interval': 60, # Validate every minute 'adaptation_threshold': 0.1, # Adapt if performance drops 10% 'min_training_samples': 100 # Minimum samples before training diff --git a/tests/test_binance_data.py b/tests/test_binance_data.py deleted file mode 100644 index 3399c32..0000000 --- a/tests/test_binance_data.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to check Binance data availability -""" - -import sys -import logging -from datetime import datetime - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_binance_data(): - """Test Binance data fetching""" - print("="*60) - print("BINANCE DATA TEST") - print("="*60) - - try: - print("1. Testing DataProvider import...") - from core.data_provider import DataProvider - print(" โœ… DataProvider imported successfully") - - print("\n2. Creating DataProvider instance...") - dp = DataProvider() - print(f" โœ… DataProvider created") - print(f" Symbols: {dp.symbols}") - print(f" Timeframes: {dp.timeframes}") - - print("\n3. Testing historical data fetch...") - try: - data = dp.get_historical_data('ETH/USDT', '1m', 10) - if data is not None: - print(f" โœ… Historical data fetched: {data.shape}") - print(f" Latest price: ${data['close'].iloc[-1]:.2f}") - print(f" Data range: {data.index[0]} to {data.index[-1]}") - else: - print(" โŒ No historical data returned") - except Exception as e: - print(f" โŒ Error fetching historical data: {e}") - - print("\n4. Testing current price...") - try: - price = dp.get_current_price('ETH/USDT') - if price: - print(f" โœ… Current price: ${price:.2f}") - else: - print(" โŒ No current price available") - except Exception as e: - print(f" โŒ Error getting current price: {e}") - - print("\n5. Testing real-time streaming setup...") - try: - # Check if streaming can be initialized - print(f" Streaming status: {dp.is_streaming}") - print(" โœ… Real-time streaming setup ready") - except Exception as e: - print(f" โŒ Real-time streaming error: {e}") - - except Exception as e: - print(f"โŒ Failed to import or create DataProvider: {e}") - import traceback - traceback.print_exc() - -def test_dashboard_connection(): - """Test if dashboard can connect to data""" - print("\n" + "="*60) - print("DASHBOARD CONNECTION TEST") - print("="*60) - - try: - print("1. Testing dashboard imports...") - from web.old_archived.scalping_dashboard import ScalpingDashboard - print(" โœ… ScalpingDashboard imported") - - print("\n2. Testing data provider connection...") - # Check if the dashboard can create a data provider - dashboard = ScalpingDashboard() - if hasattr(dashboard, 'data_provider'): - print(" โœ… Dashboard has data_provider") - print(f" Data provider symbols: {dashboard.data_provider.symbols}") - else: - print(" โŒ Dashboard missing data_provider") - - except Exception as e: - print(f"โŒ Dashboard connection error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_binance_data() - test_dashboard_connection() \ No newline at end of file diff --git a/tests/test_callback_registration.py b/tests/test_callback_registration.py deleted file mode 100644 index af52362..0000000 --- a/tests/test_callback_registration.py +++ /dev/null @@ -1,221 +0,0 @@ -#!/usr/bin/env python3 -""" -Test callback registration to identify the issue -""" - -import logging -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -import dash -from dash import dcc, html, Input, Output - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_simple_callback(): - """Test a simple callback registration""" - logger.info("Testing simple callback registration...") - - app = dash.Dash(__name__) - - app.layout = html.Div([ - html.H1("Callback Registration Test"), - html.Div(id="output", children="Initial"), - dcc.Interval(id="interval", interval=1000, n_intervals=0) - ]) - - @app.callback( - Output('output', 'children'), - Input('interval', 'n_intervals') - ) - def update_output(n_intervals): - logger.info(f"Callback triggered: {n_intervals}") - return f"Update #{n_intervals}" - - logger.info("Simple callback registered successfully") - - # Check if callback is in the callback map - logger.info(f"Callback map keys: {list(app.callback_map.keys())}") - - return app - -def test_complex_callback(): - """Test a complex callback like the dashboard""" - logger.info("Testing complex callback registration...") - - app = dash.Dash(__name__) - - app.layout = html.Div([ - html.H1("Complex Callback Test"), - html.Div(id="current-balance", children="$100.00"), - html.Div(id="session-duration", children="00:00:00"), - html.Div(id="status", children="Starting"), - dcc.Graph(id="chart"), - dcc.Interval(id="ultra-fast-interval", interval=1000, n_intervals=0) - ]) - - @app.callback( - [ - Output('current-balance', 'children'), - Output('session-duration', 'children'), - Output('status', 'children'), - Output('chart', 'figure') - ], - [Input('ultra-fast-interval', 'n_intervals')] - ) - def update_dashboard(n_intervals): - logger.info(f"Complex callback triggered: {n_intervals}") - - import plotly.graph_objects as go - fig = go.Figure() - fig.add_trace(go.Scatter(x=[1, 2, 3], y=[1, 2, 3], mode='lines')) - fig.update_layout(template="plotly_dark") - - return f"${100 + n_intervals:.2f}", f"00:00:{n_intervals:02d}", "Running", fig - - logger.info("Complex callback registered successfully") - - # Check if callback is in the callback map - logger.info(f"Callback map keys: {list(app.callback_map.keys())}") - - return app - -def test_dashboard_callback(): - """Test the exact dashboard callback structure""" - logger.info("Testing dashboard callback structure...") - - try: - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - - app = dash.Dash(__name__) - - # Minimal layout with dashboard elements - app.layout = html.Div([ - html.H1("Dashboard Callback Test"), - html.Div(id="current-balance", children="$100.00"), - html.Div(id="session-duration", children="00:00:00"), - html.Div(id="open-positions", children="0"), - html.Div(id="live-pnl", children="$0.00"), - html.Div(id="win-rate", children="0%"), - html.Div(id="total-trades", children="0"), - html.Div(id="last-action", children="WAITING"), - html.Div(id="eth-price", children="Loading..."), - html.Div(id="btc-price", children="Loading..."), - dcc.Graph(id="main-eth-1s-chart"), - dcc.Graph(id="eth-1m-chart"), - dcc.Graph(id="eth-1h-chart"), - dcc.Graph(id="eth-1d-chart"), - dcc.Graph(id="btc-1s-chart"), - html.Div(id="actions-log", children="No actions yet"), - html.Div(id="debug-status", children="Debug info"), - dcc.Interval(id="ultra-fast-interval", interval=1000, n_intervals=0) - ]) - - @app.callback( - [ - Output('current-balance', 'children'), - Output('session-duration', 'children'), - Output('open-positions', 'children'), - Output('live-pnl', 'children'), - Output('win-rate', 'children'), - Output('total-trades', 'children'), - Output('last-action', 'children'), - Output('eth-price', 'children'), - Output('btc-price', 'children'), - Output('main-eth-1s-chart', 'figure'), - Output('eth-1m-chart', 'figure'), - Output('eth-1h-chart', 'figure'), - Output('eth-1d-chart', 'figure'), - Output('btc-1s-chart', 'figure'), - Output('actions-log', 'children'), - Output('debug-status', 'children') - ], - [Input('ultra-fast-interval', 'n_intervals')] - ) - def update_dashboard_test(n_intervals): - logger.info(f"Dashboard callback triggered: {n_intervals}") - - import plotly.graph_objects as go - from datetime import datetime - - # Create empty figure - empty_fig = go.Figure() - empty_fig.update_layout(template="plotly_dark") - - debug_status = html.Div([ - html.P(f"Test Callback #{n_intervals} at {datetime.now().strftime('%H:%M:%S')}") - ]) - - return ( - f"${100 + n_intervals:.2f}", # current-balance - f"00:00:{n_intervals:02d}", # session-duration - "0", # open-positions - f"${n_intervals:+.2f}", # live-pnl - "75%", # win-rate - str(n_intervals), # total-trades - "TEST", # last-action - "$3500.00", # eth-price - "$65000.00", # btc-price - empty_fig, # main-eth-1s-chart - empty_fig, # eth-1m-chart - empty_fig, # eth-1h-chart - empty_fig, # eth-1d-chart - empty_fig, # btc-1s-chart - f"Test action #{n_intervals}", # actions-log - debug_status # debug-status - ) - - logger.info("Dashboard callback registered successfully") - logger.info(f"Callback map keys: {list(app.callback_map.keys())}") - - return app - - except Exception as e: - logger.error(f"Error testing dashboard callback: {e}") - import traceback - logger.error(f"Traceback: {traceback.format_exc()}") - return None - -def main(): - """Main test function""" - logger.info("Starting callback registration tests...") - - # Test 1: Simple callback - try: - simple_app = test_simple_callback() - logger.info("โœ… Simple callback test passed") - except Exception as e: - logger.error(f"โŒ Simple callback test failed: {e}") - - # Test 2: Complex callback - try: - complex_app = test_complex_callback() - logger.info("โœ… Complex callback test passed") - except Exception as e: - logger.error(f"โŒ Complex callback test failed: {e}") - - # Test 3: Dashboard callback - try: - dashboard_app = test_dashboard_callback() - if dashboard_app: - logger.info("โœ… Dashboard callback test passed") - - # Run the dashboard test - logger.info("Starting dashboard test server on port 8054...") - dashboard_app.run(host='127.0.0.1', port=8054, debug=True) - else: - logger.error("โŒ Dashboard callback test failed") - except Exception as e: - logger.error(f"โŒ Dashboard callback test failed: {e}") - import traceback - logger.error(f"Traceback: {traceback.format_exc()}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_callback_simple.py b/tests/test_callback_simple.py deleted file mode 100644 index 2ddf0ed..0000000 --- a/tests/test_callback_simple.py +++ /dev/null @@ -1,22 +0,0 @@ -import requests -import json - -def test_callback(): - try: - url = 'http://127.0.0.1:8051/_dash-update-component' - data = { - "output": "current-balance.children", - "inputs": [{"id": "ultra-fast-interval", "property": "n_intervals", "value": 1}], - "changedPropIds": ["ultra-fast-interval.n_intervals"], - "state": [] - } - - response = requests.post(url, json=data, timeout=10) - print(f"Status: {response.status_code}") - print(f"Response: {response.text[:1000]}") - - except Exception as e: - print(f"Error: {e}") - -if __name__ == "__main__": - test_callback() \ No newline at end of file diff --git a/tests/test_callback_structure.py b/tests/test_callback_structure.py deleted file mode 100644 index b345b6e..0000000 --- a/tests/test_callback_structure.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -""" -Test callback structure to verify it works -""" - -import dash -from dash import dcc, html, Input, Output -import plotly.graph_objects as go -from datetime import datetime -import logging - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Create Dash app -app = dash.Dash(__name__) - -# Simple layout matching the enhanced dashboard structure -app.layout = html.Div([ - html.H1("Callback Structure Test"), - html.Div(id="test-output-1"), - html.Div(id="test-output-2"), - html.Div(id="test-output-3"), - dcc.Graph(id="test-chart"), - dcc.Interval(id='test-interval', interval=3000, n_intervals=0) -]) - -# Callback using the EXACT same structure as enhanced dashboard -@app.callback( - [ - Output('test-output-1', 'children'), - Output('test-output-2', 'children'), - Output('test-output-3', 'children'), - Output('test-chart', 'figure') - ], - [Input('test-interval', 'n_intervals')] -) -def update_test_dashboard(n_intervals): - """Test callback with same structure as enhanced dashboard""" - try: - logger.info(f"Test callback triggered: {n_intervals}") - - # Simple outputs - output1 = f"Output 1: {n_intervals}" - output2 = f"Output 2: {datetime.now().strftime('%H:%M:%S')}" - output3 = f"Output 3: Working" - - # Simple chart - fig = go.Figure() - fig.add_trace(go.Scatter( - x=[1, 2, 3, 4, 5], - y=[n_intervals, n_intervals+1, n_intervals+2, n_intervals+1, n_intervals], - mode='lines', - name='Test Data' - )) - fig.update_layout( - title=f"Test Chart - Update {n_intervals}", - template="plotly_dark" - ) - - logger.info(f"Returning: {output1}, {output2}, {output3},
") - return output1, output2, output3, fig - - except Exception as e: - logger.error(f"Error in test callback: {e}") - import traceback - logger.error(f"Traceback: {traceback.format_exc()}") - - # Return safe fallback - return f"Error: {str(e)}", "Error", "Error", go.Figure() - -if __name__ == "__main__": - logger.info("Starting callback structure test on port 8053...") - app.run(host='127.0.0.1', port=8053, debug=True) \ No newline at end of file diff --git a/tests/test_dashboard_callback.py b/tests/test_dashboard_callback.py deleted file mode 100644 index ab2ec4d..0000000 --- a/tests/test_dashboard_callback.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Dashboard Callback - Simple test to verify Dash callbacks work -""" - -import logging -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -import dash -from dash import dcc, html, Input, Output -import plotly.graph_objects as go -from datetime import datetime - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def create_test_dashboard(): - """Create a simple test dashboard to verify callbacks work""" - - app = dash.Dash(__name__) - - app.layout = html.Div([ - html.H1("๐Ÿงช Test Dashboard - Callback Verification", className="text-center"), - html.Div([ - html.H3(id="current-time", className="text-center"), - html.H4(id="counter", className="text-center"), - dcc.Graph(id="test-chart") - ]), - dcc.Interval( - id='test-interval', - interval=1000, # 1 second - n_intervals=0 - ) - ]) - - @app.callback( - [ - Output('current-time', 'children'), - Output('counter', 'children'), - Output('test-chart', 'figure') - ], - [Input('test-interval', 'n_intervals')] - ) - def update_test_dashboard(n_intervals): - """Test callback function""" - try: - logger.info(f"๐Ÿ”„ Test callback triggered, interval: {n_intervals}") - - current_time = datetime.now().strftime("%H:%M:%S") - counter = f"Updates: {n_intervals}" - - # Create simple test chart - fig = go.Figure() - fig.add_trace(go.Scatter( - x=list(range(n_intervals + 1)), - y=[i**2 for i in range(n_intervals + 1)], - mode='lines+markers', - name='Test Data' - )) - fig.update_layout( - title=f"Test Chart - Update #{n_intervals}", - template="plotly_dark" - ) - - return current_time, counter, fig - - except Exception as e: - logger.error(f"Error in test callback: {e}") - return "Error", "Error", {} - - return app - -def main(): - """Run the test dashboard""" - logger.info("๐Ÿงช Starting test dashboard...") - - try: - app = create_test_dashboard() - logger.info("โœ… Test dashboard created") - - logger.info("๐Ÿš€ Starting test dashboard on http://127.0.0.1:8052") - logger.info("If you see updates every second, callbacks are working!") - logger.info("Press Ctrl+C to stop") - - app.run(host='127.0.0.1', port=8052, debug=True) - - except KeyboardInterrupt: - logger.info("Test dashboard stopped by user") - except Exception as e: - logger.error(f"โŒ Error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_dashboard_requests.py b/tests/test_dashboard_requests.py deleted file mode 100644 index aba78c4..0000000 --- a/tests/test_dashboard_requests.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to make direct requests to the dashboard's callback endpoint -""" - -import requests -import json -import time - -def test_dashboard_callback(): - """Test the dashboard callback endpoint directly""" - - dashboard_url = "http://127.0.0.1:8054" - callback_url = f"{dashboard_url}/_dash-update-component" - - print(f"Testing dashboard at {dashboard_url}") - - # First, check if dashboard is running - try: - response = requests.get(dashboard_url, timeout=5) - print(f"Dashboard status: {response.status_code}") - if response.status_code != 200: - print("Dashboard not responding properly") - return - except Exception as e: - print(f"Error connecting to dashboard: {e}") - return - - # Test callback request for dashboard test - callback_data = { - "output": "current-balance.children", - "outputs": [ - {"id": "current-balance", "property": "children"}, - {"id": "session-duration", "property": "children"}, - {"id": "open-positions", "property": "children"}, - {"id": "live-pnl", "property": "children"}, - {"id": "win-rate", "property": "children"}, - {"id": "total-trades", "property": "children"}, - {"id": "last-action", "property": "children"}, - {"id": "eth-price", "property": "children"}, - {"id": "btc-price", "property": "children"}, - {"id": "main-eth-1s-chart", "property": "figure"}, - {"id": "eth-1m-chart", "property": "figure"}, - {"id": "eth-1h-chart", "property": "figure"}, - {"id": "eth-1d-chart", "property": "figure"}, - {"id": "btc-1s-chart", "property": "figure"}, - {"id": "actions-log", "property": "children"}, - {"id": "debug-status", "property": "children"} - ], - "inputs": [ - {"id": "ultra-fast-interval", "property": "n_intervals", "value": 1} - ], - "changedPropIds": ["ultra-fast-interval.n_intervals"], - "state": [] - } - - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } - - print("\nTesting callback request...") - try: - response = requests.post( - callback_url, - data=json.dumps(callback_data), - headers=headers, - timeout=10 - ) - - print(f"Callback response status: {response.status_code}") - print(f"Response headers: {dict(response.headers)}") - - if response.status_code == 200: - try: - response_data = response.json() - print(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}") - print(f"Response data type: {type(response_data)}") - - if isinstance(response_data, dict) and 'response' in response_data: - print(f"Response contains {len(response_data['response'])} items") - for i, item in enumerate(response_data['response'][:3]): # Show first 3 items - print(f" Item {i}: {type(item)} - {str(item)[:100]}...") - else: - print(f"Full response: {str(response_data)[:500]}...") - - except json.JSONDecodeError as e: - print(f"Error parsing JSON response: {e}") - print(f"Raw response: {response.text[:500]}...") - else: - print(f"Error response: {response.text}") - - except Exception as e: - print(f"Error making callback request: {e}") - -def monitor_dashboard(): - """Monitor dashboard callback requests""" - print("Monitoring dashboard callback requests...") - print("Press Ctrl+C to stop") - - try: - for i in range(10): # Test 10 times - print(f"\n--- Test {i+1} ---") - test_dashboard_callback() - time.sleep(2) - except KeyboardInterrupt: - print("\nMonitoring stopped") - -if __name__ == "__main__": - monitor_dashboard() \ No newline at end of file diff --git a/tests/test_dashboard_simple.py b/tests/test_dashboard_simple.py deleted file mode 100644 index af1a99b..0000000 --- a/tests/test_dashboard_simple.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple Dashboard Test - Isolate dashboard startup issues -""" - -import os -# Fix OpenMP library conflicts before importing other modules -os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -os.environ['OMP_NUM_THREADS'] = '4' - -import sys -import logging -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup basic logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_dashboard_startup(): - """Test dashboard creation and startup""" - try: - logger.info("=" * 50) - logger.info("TESTING DASHBOARD STARTUP") - logger.info("=" * 50) - - # Test imports first - logger.info("Step 1: Testing imports...") - from core.config import get_config, setup_logging - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - from core.trading_executor import TradingExecutor - logger.info("โœ“ Core imports successful") - - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - logger.info("โœ“ Dashboard import successful") - - # Test configuration - logger.info("Step 2: Testing configuration...") - setup_logging() - config = get_config() - logger.info("โœ“ Configuration loaded") - - # Test core component creation - logger.info("Step 3: Testing core component creation...") - data_provider = DataProvider() - logger.info("โœ“ DataProvider created") - - orchestrator = TradingOrchestrator(data_provider=data_provider) - logger.info("โœ“ TradingOrchestrator created") - - trading_executor = TradingExecutor() - logger.info("โœ“ TradingExecutor created") - - # Test dashboard creation - logger.info("Step 4: Testing dashboard creation...") - dashboard = TradingDashboard( - data_provider=data_provider, - orchestrator=orchestrator, - trading_executor=trading_executor - ) - logger.info("โœ“ TradingDashboard created successfully") - - # Test dashboard startup - logger.info("Step 5: Testing dashboard server startup...") - logger.info("Dashboard will start on http://127.0.0.1:8052") - logger.info("Press Ctrl+C to stop the test") - - # Run the dashboard - dashboard.app.run( - host='127.0.0.1', - port=8052, - debug=False, - use_reloader=False - ) - - except Exception as e: - logger.error(f"โŒ Dashboard test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - - return True - -if __name__ == "__main__": - try: - success = test_dashboard_startup() - if success: - logger.info("โœ“ Dashboard test completed successfully") - else: - logger.error("โŒ Dashboard test failed") - sys.exit(1) - except KeyboardInterrupt: - logger.info("Dashboard test interrupted by user") - except Exception as e: - logger.error(f"Fatal error in dashboard test: {e}") - sys.exit(1) \ No newline at end of file diff --git a/tests/test_dashboard_startup.py b/tests/test_dashboard_startup.py deleted file mode 100644 index c99fc41..0000000 --- a/tests/test_dashboard_startup.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Dashboard Startup - Debug the scalping dashboard startup issue -""" - -import logging -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_dashboard_startup(): - """Test dashboard startup with detailed error reporting""" - try: - logger.info("Testing dashboard startup...") - - # Test imports - logger.info("Testing imports...") - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - from web.old_archived.scalping_dashboard import create_scalping_dashboard - logger.info("โœ… All imports successful") - - # Test data provider - logger.info("Creating data provider...") - dp = DataProvider() - logger.info("โœ… Data provider created") - - # Test orchestrator - logger.info("Creating orchestrator...") - orch = EnhancedTradingOrchestrator(dp) - logger.info("โœ… Orchestrator created") - - # Test dashboard creation - logger.info("Creating dashboard...") - dashboard = create_scalping_dashboard(dp, orch) - logger.info("โœ… Dashboard created successfully") - - # Test data fetching - logger.info("Testing data fetching...") - test_data = dp.get_historical_data('ETH/USDT', '1m', limit=5) - if test_data is not None and not test_data.empty: - logger.info(f"โœ… Data fetching works: {len(test_data)} candles") - else: - logger.warning("โš ๏ธ No data returned from data provider") - - # Start dashboard - logger.info("Starting dashboard on http://127.0.0.1:8051") - logger.info("Press Ctrl+C to stop") - dashboard.run(host='127.0.0.1', port=8051, debug=True) - - except KeyboardInterrupt: - logger.info("Dashboard stopped by user") - except Exception as e: - logger.error(f"โŒ Error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_dashboard_startup() \ No newline at end of file diff --git a/tests/test_enhanced_cob_integration.py b/tests/test_enhanced_cob_integration.py deleted file mode 100644 index 95cfb08..0000000 --- a/tests/test_enhanced_cob_integration.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced COB Integration with RL and CNN Models - -This script tests the integration of Consolidated Order Book (COB) data -with the real-time RL and CNN training pipeline. -""" - -import asyncio -import logging -import sys -from pathlib import Path -import numpy as np -import time -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.cob_integration import COBIntegration - -# Setup logging -setup_logging() -logger = logging.getLogger(__name__) - -class COBMLIntegrationTester: - """Test COB integration with ML models""" - - def __init__(self): - self.symbols = ['BTC/USDT', 'ETH/USDT'] - self.data_provider = DataProvider() - self.test_results = {} - - async def test_cob_ml_integration(self): - """Test full COB integration with ML pipeline""" - logger.info("=" * 60) - logger.info("TESTING COB INTEGRATION WITH RL AND CNN MODELS") - logger.info("=" * 60) - - try: - # Initialize enhanced orchestrator with COB integration - logger.info("1. Initializing Enhanced Trading Orchestrator with COB...") - orchestrator = EnhancedTradingOrchestrator( - data_provider=self.data_provider, - symbols=self.symbols, - enhanced_rl_training=True, - model_registry={} - ) - - # Start COB integration - logger.info("2. Starting COB Integration...") - await orchestrator.start_cob_integration() - await asyncio.sleep(5) # Allow startup and data collection - - # Test COB feature generation - logger.info("3. Testing COB feature generation...") - await self._test_cob_features(orchestrator) - - # Test market state with COB data - logger.info("4. Testing market state with COB data...") - await self._test_market_state_cob(orchestrator) - - # Test real-time COB callbacks - logger.info("5. Testing real-time COB callbacks...") - await self._test_realtime_callbacks(orchestrator) - - # Stop COB integration - await orchestrator.stop_cob_integration() - - # Print results - self._print_test_results() - - except Exception as e: - logger.error(f"Error in COB ML integration test: {e}") - import traceback - logger.error(traceback.format_exc()) - - async def _test_cob_features(self, orchestrator): - """Test COB feature availability""" - try: - for symbol in self.symbols: - # Check if COB features are available - cob_features = orchestrator.latest_cob_features.get(symbol) - cob_state = orchestrator.latest_cob_state.get(symbol) - - if cob_features is not None: - logger.info(f"โœ… {symbol}: COB CNN features available - shape: {cob_features.shape}") - self.test_results[f'{symbol}_cob_cnn_features'] = True - else: - logger.warning(f"โš ๏ธ {symbol}: COB CNN features not available") - self.test_results[f'{symbol}_cob_cnn_features'] = False - - if cob_state is not None: - logger.info(f"โœ… {symbol}: COB DQN state available - shape: {cob_state.shape}") - self.test_results[f'{symbol}_cob_dqn_state'] = True - else: - logger.warning(f"โš ๏ธ {symbol}: COB DQN state not available") - self.test_results[f'{symbol}_cob_dqn_state'] = False - - except Exception as e: - logger.error(f"Error testing COB features: {e}") - - async def _test_market_state_cob(self, orchestrator): - """Test market state includes COB data""" - try: - # Generate market states with COB data - from core.universal_data_adapter import UniversalDataAdapter - adapter = UniversalDataAdapter(self.data_provider) - universal_stream = await adapter.get_universal_stream(['BTC/USDT', 'ETH/USDT']) - - market_states = await orchestrator._get_all_market_states_universal(universal_stream) - - for symbol in self.symbols: - if symbol in market_states: - state = market_states[symbol] - - # Check COB integration in market state - tests = [ - ('cob_features', state.cob_features is not None), - ('cob_state', state.cob_state is not None), - ('order_book_imbalance', hasattr(state, 'order_book_imbalance')), - ('liquidity_depth', hasattr(state, 'liquidity_depth')), - ('exchange_diversity', hasattr(state, 'exchange_diversity')), - ('market_impact_estimate', hasattr(state, 'market_impact_estimate')) - ] - - for test_name, passed in tests: - status = "โœ…" if passed else "โŒ" - logger.info(f"{status} {symbol}: {test_name} - {passed}") - self.test_results[f'{symbol}_market_state_{test_name}'] = passed - - # Log COB metrics if available - if hasattr(state, 'order_book_imbalance'): - logger.info(f"๐Ÿ“Š {symbol} COB Metrics:") - logger.info(f" Order Book Imbalance: {state.order_book_imbalance:.4f}") - logger.info(f" Liquidity Depth: ${state.liquidity_depth:,.0f}") - logger.info(f" Exchange Diversity: {state.exchange_diversity}") - logger.info(f" Market Impact (10k): {state.market_impact_estimate:.4f}%") - - except Exception as e: - logger.error(f"Error testing market state COB: {e}") - - async def _test_realtime_callbacks(self, orchestrator): - """Test real-time COB callbacks""" - try: - # Monitor COB callbacks for 10 seconds - initial_features = {s: len(orchestrator.cob_feature_history[s]) for s in self.symbols} - - logger.info("Monitoring COB callbacks for 10 seconds...") - await asyncio.sleep(10) - - final_features = {s: len(orchestrator.cob_feature_history[s]) for s in self.symbols} - - for symbol in self.symbols: - updates = final_features[symbol] - initial_features[symbol] - if updates > 0: - logger.info(f"โœ… {symbol}: Received {updates} COB feature updates") - self.test_results[f'{symbol}_realtime_callbacks'] = True - else: - logger.warning(f"โš ๏ธ {symbol}: No COB feature updates received") - self.test_results[f'{symbol}_realtime_callbacks'] = False - - except Exception as e: - logger.error(f"Error testing realtime callbacks: {e}") - - def _print_test_results(self): - """Print comprehensive test results""" - logger.info("=" * 60) - logger.info("COB ML INTEGRATION TEST RESULTS") - logger.info("=" * 60) - - passed = sum(1 for result in self.test_results.values() if result) - total = len(self.test_results) - - logger.info(f"Overall: {passed}/{total} tests passed ({passed/total*100:.1f}%)") - logger.info("") - - for test_name, result in self.test_results.items(): - status = "โœ… PASS" if result else "โŒ FAIL" - logger.info(f"{status}: {test_name}") - - logger.info("=" * 60) - - if passed == total: - logger.info("๐ŸŽ‰ ALL TESTS PASSED - COB ML INTEGRATION WORKING!") - elif passed > total * 0.8: - logger.info("โš ๏ธ MOSTLY WORKING - Some minor issues detected") - else: - logger.warning("๐Ÿšจ INTEGRATION ISSUES - Significant problems detected") - -async def main(): - """Run COB ML integration tests""" - tester = COBMLIntegrationTester() - await tester.test_cob_ml_integration() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_enhanced_dashboard.py b/tests/test_enhanced_dashboard.py deleted file mode 100644 index a50361d..0000000 --- a/tests/test_enhanced_dashboard.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for enhanced trading dashboard with WebSocket support -""" - -import sys -import logging -from datetime import datetime - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_dashboard(): - """Test the enhanced dashboard functionality""" - try: - print("="*60) - print("TESTING ENHANCED TRADING DASHBOARD") - print("="*60) - - # Import dashboard - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard -WEBSOCKET_AVAILABLE = True - - print(f"โœ“ Dashboard module imported successfully") - print(f"โœ“ WebSocket support available: {WEBSOCKET_AVAILABLE}") - - # Create dashboard instance - dashboard = TradingDashboard() - - print(f"โœ“ Dashboard instance created") - print(f"โœ“ Tick cache capacity: {dashboard.tick_cache.maxlen} ticks (15 min)") - print(f"โœ“ 1s bars capacity: {dashboard.one_second_bars.maxlen} bars (15 min)") - print(f"โœ“ WebSocket streaming: {dashboard.is_streaming}") - print(f"โœ“ Min confidence threshold: {dashboard.min_confidence_threshold}") - print(f"โœ“ Signal cooldown: {dashboard.signal_cooldown}s") - - # Test tick cache methods - tick_cache = dashboard.get_tick_cache_for_training(minutes=5) - print(f"โœ“ Tick cache method works: {len(tick_cache)} ticks") - - # Test 1s bars method - bars_df = dashboard.get_one_second_bars(count=100) - print(f"โœ“ 1s bars method works: {len(bars_df)} bars") - - # Test chart creation - try: - chart = dashboard._create_price_chart("ETH/USDT") - print(f"โœ“ Price chart creation works") - except Exception as e: - print(f"โš  Price chart creation: {e}") - - print("\n" + "="*60) - print("ENHANCED DASHBOARD FEATURES:") - print("="*60) - print("โœ“ Real-time WebSocket tick streaming (when websocket-client installed)") - print("โœ“ 1-second bar charts with volume") - print("โœ“ 15-minute tick cache for model training") - print("โœ“ Confidence-based signal execution") - print("โœ“ Clear signal vs execution distinction") - print("โœ“ Real-time unrealized P&L display") - print("โœ“ Compact layout with system status icon") - print("โœ“ Scalping-optimized signal generation") - - print("\n" + "="*60) - print("TO START THE DASHBOARD:") - print("="*60) - print("1. Install WebSocket support: pip install websocket-client") - print("2. Run: python -c \"from web.dashboard import TradingDashboard; TradingDashboard().run()\"") - print("3. Open browser: http://127.0.0.1:8050") - print("="*60) - - return True - - except Exception as e: - print(f"โŒ Error testing dashboard: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_dashboard() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_enhanced_dashboard_integration.py b/tests/test_enhanced_dashboard_integration.py deleted file mode 100644 index 831ecdd..0000000 --- a/tests/test_enhanced_dashboard_integration.py +++ /dev/null @@ -1,305 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Dashboard Integration with RL Training Pipeline - -This script tests the integration between the dashboard and the enhanced RL training pipeline -to verify that: -1. Unified data stream is properly initialized -2. Dashboard receives training data from the enhanced pipeline -3. Data flows correctly between components -4. Enhanced RL training receives comprehensive data -""" - -import asyncio -import logging -import time -import sys -from datetime import datetime -from pathlib import Path - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('test_enhanced_dashboard_integration.log'), - logging.StreamHandler(sys.stdout) - ] -) - -logger = logging.getLogger(__name__) - -# Import components -from core.config import get_config -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream -from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard - -class EnhancedDashboardIntegrationTest: - """Test enhanced dashboard integration with RL training pipeline""" - - def __init__(self): - """Initialize test components""" - self.config = get_config() - self.data_provider = None - self.orchestrator = None - self.unified_stream = None - self.dashboard = None - - # Test results - self.test_results = { - 'data_provider_init': False, - 'orchestrator_init': False, - 'unified_stream_init': False, - 'dashboard_init': False, - 'data_flow_test': False, - 'training_integration_test': False, - 'ui_data_test': False, - 'stream_stats_test': False - } - - logger.info("Enhanced Dashboard Integration Test initialized") - - async def run_tests(self): - """Run all integration tests""" - logger.info("Starting enhanced dashboard integration tests...") - - try: - # Test 1: Initialize components - await self.test_component_initialization() - - # Test 2: Test data flow - await self.test_data_flow() - - # Test 3: Test training integration - await self.test_training_integration() - - # Test 4: Test UI data flow - await self.test_ui_data_flow() - - # Test 5: Test stream statistics - await self.test_stream_statistics() - - # Generate test report - self.generate_test_report() - - except Exception as e: - logger.error(f"Test execution failed: {e}") - raise - - async def test_component_initialization(self): - """Test component initialization""" - logger.info("Testing component initialization...") - - try: - # Initialize data provider - self.data_provider = DataProvider( - symbols=['ETH/USDT', 'BTC/USDT'], - timeframes=['1s', '1m', '1h', '1d'] - ) - self.test_results['data_provider_init'] = True - logger.info("โœ“ Data provider initialized") - - # Initialize orchestrator - self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) - self.test_results['orchestrator_init'] = True - logger.info("โœ“ Enhanced orchestrator initialized") - - # Initialize unified stream - self.unified_stream = UnifiedDataStream(self.data_provider, self.orchestrator) - self.test_results['unified_stream_init'] = True - logger.info("โœ“ Unified data stream initialized") - - # Initialize dashboard - self.dashboard = RealTimeScalpingDashboard( - data_provider=self.data_provider, - orchestrator=self.orchestrator - ) - self.test_results['dashboard_init'] = True - logger.info("โœ“ Dashboard initialized with unified stream integration") - - except Exception as e: - logger.error(f"Component initialization failed: {e}") - raise - - async def test_data_flow(self): - """Test data flow through unified stream""" - logger.info("Testing data flow through unified stream...") - - try: - # Start unified streaming - await self.unified_stream.start_streaming() - - # Wait for data collection - logger.info("Waiting for data collection...") - await asyncio.sleep(10) - - # Check if data is flowing - stream_stats = self.unified_stream.get_stream_stats() - - if stream_stats['tick_cache_size'] > 0: - logger.info(f"โœ“ Tick data flowing: {stream_stats['tick_cache_size']} ticks") - self.test_results['data_flow_test'] = True - else: - logger.warning("โš  No tick data detected") - - if stream_stats['one_second_bars_count'] > 0: - logger.info(f"โœ“ 1s bars generated: {stream_stats['one_second_bars_count']} bars") - else: - logger.warning("โš  No 1s bars generated") - - logger.info(f"Stream statistics: {stream_stats}") - - except Exception as e: - logger.error(f"Data flow test failed: {e}") - raise - - async def test_training_integration(self): - """Test training data integration""" - logger.info("Testing training data integration...") - - try: - # Get latest training data - training_data = self.unified_stream.get_latest_training_data() - - if training_data: - logger.info("โœ“ Training data packet available") - logger.info(f" Tick cache: {len(training_data.tick_cache)} ticks") - logger.info(f" 1s bars: {len(training_data.one_second_bars)} bars") - logger.info(f" Multi-timeframe data: {len(training_data.multi_timeframe_data)} symbols") - logger.info(f" CNN features: {'Available' if training_data.cnn_features else 'Not available'}") - logger.info(f" CNN predictions: {'Available' if training_data.cnn_predictions else 'Not available'}") - logger.info(f" Market state: {'Available' if training_data.market_state else 'Not available'}") - logger.info(f" Universal stream: {'Available' if training_data.universal_stream else 'Not available'}") - - # Check if dashboard can access training data - if hasattr(self.dashboard, 'latest_training_data') and self.dashboard.latest_training_data: - logger.info("โœ“ Dashboard has access to training data") - self.test_results['training_integration_test'] = True - else: - logger.warning("โš  Dashboard does not have training data access") - else: - logger.warning("โš  No training data available") - - except Exception as e: - logger.error(f"Training integration test failed: {e}") - raise - - async def test_ui_data_flow(self): - """Test UI data flow""" - logger.info("Testing UI data flow...") - - try: - # Get latest UI data - ui_data = self.unified_stream.get_latest_ui_data() - - if ui_data: - logger.info("โœ“ UI data packet available") - logger.info(f" Current prices: {ui_data.current_prices}") - logger.info(f" Tick cache size: {ui_data.tick_cache_size}") - logger.info(f" 1s bars count: {ui_data.one_second_bars_count}") - logger.info(f" Streaming status: {ui_data.streaming_status}") - logger.info(f" Training data available: {ui_data.training_data_available}") - - # Check if dashboard can access UI data - if hasattr(self.dashboard, 'latest_ui_data') and self.dashboard.latest_ui_data: - logger.info("โœ“ Dashboard has access to UI data") - self.test_results['ui_data_test'] = True - else: - logger.warning("โš  Dashboard does not have UI data access") - else: - logger.warning("โš  No UI data available") - - except Exception as e: - logger.error(f"UI data flow test failed: {e}") - raise - - async def test_stream_statistics(self): - """Test stream statistics""" - logger.info("Testing stream statistics...") - - try: - # Get comprehensive stream stats - stream_stats = self.unified_stream.get_stream_stats() - - logger.info("Stream Statistics:") - logger.info(f" Total ticks processed: {stream_stats.get('total_ticks_processed', 0)}") - logger.info(f" Total packets sent: {stream_stats.get('total_packets_sent', 0)}") - logger.info(f" Consumers served: {stream_stats.get('consumers_served', 0)}") - logger.info(f" Active consumers: {stream_stats.get('active_consumers', 0)}") - logger.info(f" Total consumers: {stream_stats.get('total_consumers', 0)}") - logger.info(f" Processing errors: {stream_stats.get('processing_errors', 0)}") - logger.info(f" Data quality score: {stream_stats.get('data_quality_score', 0.0)}") - - if stream_stats.get('active_consumers', 0) > 0: - logger.info("โœ“ Stream has active consumers") - self.test_results['stream_stats_test'] = True - else: - logger.warning("โš  No active consumers detected") - - except Exception as e: - logger.error(f"Stream statistics test failed: {e}") - raise - - def generate_test_report(self): - """Generate comprehensive test report""" - logger.info("Generating test report...") - - total_tests = len(self.test_results) - passed_tests = sum(self.test_results.values()) - - logger.info("=" * 60) - logger.info("ENHANCED DASHBOARD INTEGRATION TEST REPORT") - logger.info("=" * 60) - logger.info(f"Test Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - logger.info(f"Total Tests: {total_tests}") - logger.info(f"Passed Tests: {passed_tests}") - logger.info(f"Failed Tests: {total_tests - passed_tests}") - logger.info(f"Success Rate: {(passed_tests / total_tests) * 100:.1f}%") - logger.info("") - - logger.info("Test Results:") - for test_name, result in self.test_results.items(): - status = "โœ“ PASS" if result else "โœ— FAIL" - logger.info(f" {test_name}: {status}") - - logger.info("") - - if passed_tests == total_tests: - logger.info("๐ŸŽ‰ ALL TESTS PASSED! Enhanced dashboard integration is working correctly.") - logger.info("The dashboard now properly integrates with the enhanced RL training pipeline.") - else: - logger.warning("โš  Some tests failed. Please review the integration.") - - logger.info("=" * 60) - - async def cleanup(self): - """Cleanup test resources""" - logger.info("Cleaning up test resources...") - - try: - if self.unified_stream: - await self.unified_stream.stop_streaming() - - if self.dashboard: - self.dashboard.stop_streaming() - - logger.info("โœ“ Cleanup completed") - - except Exception as e: - logger.error(f"Cleanup failed: {e}") - -async def main(): - """Main test execution""" - test = EnhancedDashboardIntegrationTest() - - try: - await test.run_tests() - except Exception as e: - logger.error(f"Test execution failed: {e}") - finally: - await test.cleanup() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_enhanced_fee_tracking.py b/tests/test_enhanced_fee_tracking.py deleted file mode 100644 index 9046b01..0000000 --- a/tests/test_enhanced_fee_tracking.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify enhanced fee tracking with maker/taker fees -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -import logging -from datetime import datetime, timezone -from web.clean_dashboard import CleanTradingDashboard as TradingDashboard -from core.data_provider import DataProvider - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_enhanced_fee_tracking(): - """Test enhanced fee tracking with maker/taker fees""" - - logger.info("Testing enhanced fee tracking...") - - # Create dashboard instance - data_provider = DataProvider() - dashboard = TradingDashboard(data_provider=data_provider) - - # Create test trading decisions with different fee types - test_decisions = [ - { - 'action': 'BUY', - 'symbol': 'ETH/USDT', - 'price': 3500.0, - 'confidence': 0.8, - 'timestamp': datetime.now(timezone.utc), - 'order_type': 'market', # Should use taker fee - 'filled_as_maker': False - }, - { - 'action': 'SELL', - 'symbol': 'ETH/USDT', - 'price': 3520.0, - 'confidence': 0.9, - 'timestamp': datetime.now(timezone.utc), - 'order_type': 'limit', # Should use maker fee if filled as maker - 'filled_as_maker': True - } - ] - - # Process the trading decisions - for i, decision in enumerate(test_decisions): - logger.info(f"Processing decision {i+1}: {decision['action']} @ ${decision['price']}") - dashboard._process_trading_decision(decision) - - # Check session trades - if dashboard.session_trades: - latest_trade = dashboard.session_trades[-1] - fee_type = latest_trade.get('fee_type', 'unknown') - fee_rate = latest_trade.get('fee_rate', 0) - fees = latest_trade.get('fees', 0) - - logger.info(f" Trade recorded: {latest_trade.get('position_action', 'unknown')}") - logger.info(f" Fee Type: {fee_type}") - logger.info(f" Fee Rate: {fee_rate*100:.3f}%") - logger.info(f" Fee Amount: ${fees:.4f}") - - # Check closed trades - if dashboard.closed_trades: - logger.info(f"\nClosed trades: {len(dashboard.closed_trades)}") - for trade in dashboard.closed_trades: - logger.info(f" Trade #{trade['trade_id']}: {trade['side']}") - logger.info(f" Fee Type: {trade.get('fee_type', 'unknown')}") - logger.info(f" Fee Rate: {trade.get('fee_rate', 0)*100:.3f}%") - logger.info(f" Total Fees: ${trade.get('fees', 0):.4f}") - logger.info(f" Net P&L: ${trade.get('net_pnl', 0):.2f}") - - # Test session performance with fee breakdown - logger.info("\nTesting session performance display...") - performance = dashboard._create_session_performance() - logger.info(f"Session performance components: {len(performance)}") - - # Test closed trades table - logger.info("\nTesting enhanced trades table...") - table_components = dashboard._create_closed_trades_table() - logger.info(f"Table components: {len(table_components)}") - - logger.info("Enhanced fee tracking test completed!") - - return True - -if __name__ == "__main__": - test_enhanced_fee_tracking() \ No newline at end of file diff --git a/tests/test_enhanced_improvements.py b/tests/test_enhanced_improvements.py deleted file mode 100644 index 5e1409a..0000000 --- a/tests/test_enhanced_improvements.py +++ /dev/null @@ -1,243 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Trading System Improvements - -This script tests: -1. Color-coded position display ([LONG] green, [SHORT] red) -2. Enhanced model training detection and retrospective learning -3. Lower confidence thresholds for closing positions (0.25 vs 0.6 for opening) -4. Perfect opportunity detection and learning -""" - -import asyncio -import logging -import time -from datetime import datetime, timedelta -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator, TradingAction -from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard, TradingSession - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_color_coded_positions(): - """Test color-coded position display functionality""" - logger.info("=== Testing Color-Coded Position Display ===") - - # Create trading session - session = TradingSession() - - # Simulate some positions - session.positions = { - 'ETH/USDT': { - 'side': 'LONG', - 'size': 0.1, - 'entry_price': 2558.15 - }, - 'BTC/USDT': { - 'side': 'SHORT', - 'size': 0.05, - 'entry_price': 45123.45 - } - } - - logger.info("Created test positions:") - logger.info(f"ETH/USDT: LONG 0.1 @ $2558.15") - logger.info(f"BTC/USDT: SHORT 0.05 @ $45123.45") - - # Test position display logic (simulating dashboard logic) - live_prices = {'ETH/USDT': 2565.30, 'BTC/USDT': 45050.20} - - for symbol, pos in session.positions.items(): - side = pos['side'] - size = pos['size'] - entry_price = pos['entry_price'] - current_price = live_prices.get(symbol, entry_price) - - # Calculate unrealized P&L - if side == 'LONG': - unrealized_pnl = (current_price - entry_price) * size - color_class = "text-success" # Green for LONG - side_display = "[LONG]" - else: # SHORT - unrealized_pnl = (entry_price - current_price) * size - color_class = "text-danger" # Red for SHORT - side_display = "[SHORT]" - - position_text = f"{side_display} {size:.3f} @ ${entry_price:.2f} | P&L: ${unrealized_pnl:+.2f}" - logger.info(f"Position Display: {position_text} (Color: {color_class})") - - logger.info("โœ… Color-coded position display test completed") - -def test_confidence_thresholds(): - """Test different confidence thresholds for opening vs closing""" - logger.info("=== Testing Confidence Thresholds ===") - - # Create orchestrator - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - logger.info(f"Opening threshold: {orchestrator.confidence_threshold_open}") - logger.info(f"Closing threshold: {orchestrator.confidence_threshold_close}") - - # Test opening action with medium confidence - test_confidence = 0.45 - logger.info(f"\nTesting opening action with confidence {test_confidence}") - - if test_confidence >= orchestrator.confidence_threshold_open: - logger.info("โœ… Would OPEN position (confidence above opening threshold)") - else: - logger.info("โŒ Would NOT open position (confidence below opening threshold)") - - # Test closing action with same confidence - logger.info(f"Testing closing action with confidence {test_confidence}") - - if test_confidence >= orchestrator.confidence_threshold_close: - logger.info("โœ… Would CLOSE position (confidence above closing threshold)") - else: - logger.info("โŒ Would NOT close position (confidence below closing threshold)") - - logger.info("โœ… Confidence threshold test completed") - -def test_retrospective_learning(): - """Test retrospective learning and perfect opportunity detection""" - logger.info("=== Testing Retrospective Learning ===") - - # Create orchestrator - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Simulate perfect moves - from core.enhanced_orchestrator import PerfectMove - - perfect_move = PerfectMove( - symbol='ETH/USDT', - timeframe='1m', - timestamp=datetime.now(), - optimal_action='BUY', - actual_outcome=0.025, # 2.5% price increase - market_state_before=None, - market_state_after=None, - confidence_should_have_been=0.85 - ) - - orchestrator.perfect_moves.append(perfect_move) - orchestrator.retrospective_learning_active = True - - logger.info(f"Added perfect move: {perfect_move.optimal_action} {perfect_move.symbol}") - logger.info(f"Outcome: {perfect_move.actual_outcome*100:+.2f}%") - logger.info(f"Confidence should have been: {perfect_move.confidence_should_have_been:.3f}") - - # Test performance metrics - metrics = orchestrator.get_performance_metrics() - retro_metrics = metrics['retrospective_learning'] - - logger.info(f"Retrospective learning active: {retro_metrics['active']}") - logger.info(f"Recent perfect moves: {retro_metrics['perfect_moves_recent']}") - logger.info(f"Average confidence needed: {retro_metrics['avg_confidence_needed']:.3f}") - - logger.info("โœ… Retrospective learning test completed") - -async def test_tick_pattern_detection(): - """Test tick pattern detection for violent moves""" - logger.info("=== Testing Tick Pattern Detection ===") - - # Create orchestrator - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Simulate violent tick - from core.tick_aggregator import RawTick - - violent_tick = RawTick( - timestamp=datetime.now(), - price=2560.0, - volume=1000.0, - quantity=0.5, - side='buy', - trade_id='test123', - time_since_last=25.0, # Very fast tick (25ms) - price_change=5.0, # $5 price jump - volume_intensity=3.5 # High volume - ) - - # Add symbol attribute for testing - violent_tick.symbol = 'ETH/USDT' - - logger.info(f"Simulating violent tick:") - logger.info(f"Price change: ${violent_tick.price_change:+.2f}") - logger.info(f"Time since last: {violent_tick.time_since_last:.0f}ms") - logger.info(f"Volume intensity: {violent_tick.volume_intensity:.1f}x") - - # Process the tick - orchestrator._handle_raw_tick(violent_tick) - - # Check if perfect move was created - if orchestrator.perfect_moves: - latest_move = orchestrator.perfect_moves[-1] - logger.info(f"โœ… Perfect move detected: {latest_move.optimal_action}") - logger.info(f"Confidence: {latest_move.confidence_should_have_been:.3f}") - else: - logger.info("โŒ No perfect move detected") - - logger.info("โœ… Tick pattern detection test completed") - -def test_dashboard_integration(): - """Test dashboard integration with new features""" - logger.info("=== Testing Dashboard Integration ===") - - # Create components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Test model training status - metrics = orchestrator.get_performance_metrics() - - logger.info("Model Training Metrics:") - logger.info(f"Perfect moves: {metrics['perfect_moves']}") - logger.info(f"RL queue size: {metrics['rl_queue_size']}") - logger.info(f"Retrospective learning: {metrics['retrospective_learning']}") - logger.info(f"Position tracking: {metrics['position_tracking']}") - logger.info(f"Thresholds: {metrics['thresholds']}") - - logger.info("โœ… Dashboard integration test completed") - -async def main(): - """Run all tests""" - logger.info("๐Ÿš€ Starting Enhanced Trading System Tests") - logger.info("=" * 60) - - try: - # Run tests - test_color_coded_positions() - print() - - test_confidence_thresholds() - print() - - test_retrospective_learning() - print() - - await test_tick_pattern_detection() - print() - - test_dashboard_integration() - print() - - logger.info("=" * 60) - logger.info("๐ŸŽ‰ All tests completed successfully!") - logger.info("Key improvements verified:") - logger.info("โœ… Color-coded positions ([LONG] green, [SHORT] red)") - logger.info("โœ… Lower closing thresholds (0.25 vs 0.6)") - logger.info("โœ… Retrospective learning on perfect opportunities") - logger.info("โœ… Enhanced model training detection") - logger.info("โœ… Violent move pattern detection") - - except Exception as e: - logger.error(f"โŒ Test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_enhanced_orchestrator_fixed.py b/tests/test_enhanced_orchestrator_fixed.py deleted file mode 100644 index 9062ba5..0000000 --- a/tests/test_enhanced_orchestrator_fixed.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Orchestrator - Bypass COB Integration Issues - -Simple test to verify enhanced orchestrator methods work -and the dashboard can use them for comprehensive RL training. -""" - -import sys -import os -from pathlib import Path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -def test_enhanced_orchestrator_bypass_cob(): - """Test enhanced orchestrator without COB integration""" - print("=" * 60) - print("TESTING ENHANCED ORCHESTRATOR (BYPASS COB INTEGRATION)") - print("=" * 60) - - try: - # Import required modules - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - print("โœ“ Basic imports successful") - - # Create basic orchestrator first - dp = DataProvider() - basic_orch = TradingOrchestrator(dp) - print("โœ“ Basic TradingOrchestrator created") - - # Test basic orchestrator methods - basic_methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward'] - print("\nBasic TradingOrchestrator methods:") - for method in basic_methods: - has_method = hasattr(basic_orch, method) - print(f" {method}: {'โœ“' if has_method else 'โœ—'}") - - # Now test by manually adding the missing methods to basic orchestrator - print("\n" + "-" * 50) - print("ADDING MISSING METHODS TO BASIC ORCHESTRATOR") - print("-" * 50) - - # Add the missing methods manually - def build_comprehensive_rl_state_fallback(self, symbol: str) -> list: - """Fallback comprehensive RL state builder""" - try: - # Create a comprehensive state with ~13,400 features - comprehensive_features = [] - - # ETH Tick Features (3000) - comprehensive_features.extend([0.0] * 3000) - - # ETH Multi-timeframe OHLCV (8000) - comprehensive_features.extend([0.0] * 8000) - - # BTC Reference Data (1000) - comprehensive_features.extend([0.0] * 1000) - - # CNN Hidden Features (1000) - comprehensive_features.extend([0.0] * 1000) - - # Pivot Analysis (300) - comprehensive_features.extend([0.0] * 300) - - # Market Microstructure (100) - comprehensive_features.extend([0.0] * 100) - - print(f"โœ“ Built comprehensive RL state: {len(comprehensive_features)} features") - return comprehensive_features - - except Exception as e: - print(f"โœ— Error building comprehensive RL state: {e}") - return None - - def calculate_enhanced_pivot_reward_fallback(self, trade_decision, market_data, trade_outcome) -> float: - """Fallback enhanced pivot reward calculation""" - try: - # Calculate enhanced reward based on trade metrics - base_pnl = trade_outcome.get('net_pnl', 0) - base_reward = base_pnl / 100.0 # Normalize - - # Add pivot analysis bonus - pivot_bonus = 0.1 if base_pnl > 0 else -0.05 - - enhanced_reward = base_reward + pivot_bonus - print(f"โœ“ Enhanced pivot reward calculated: {enhanced_reward:.4f}") - return enhanced_reward - - except Exception as e: - print(f"โœ— Error calculating enhanced pivot reward: {e}") - return 0.0 - - # Bind methods to the orchestrator instance - import types - basic_orch.build_comprehensive_rl_state = types.MethodType(build_comprehensive_rl_state_fallback, basic_orch) - basic_orch.calculate_enhanced_pivot_reward = types.MethodType(calculate_enhanced_pivot_reward_fallback, basic_orch) - - print("\nโœ“ Enhanced methods added to basic orchestrator") - - # Test the enhanced methods - print("\nTesting enhanced methods:") - - # Test comprehensive RL state building - state = basic_orch.build_comprehensive_rl_state('ETH/USDT') - print(f" Comprehensive RL state: {'โœ“' if state and len(state) > 10000 else 'โœ—'} ({len(state) if state else 0} features)") - - # Test enhanced reward calculation - mock_trade = {'net_pnl': 50.0} - reward = basic_orch.calculate_enhanced_pivot_reward({}, {}, mock_trade) - print(f" Enhanced pivot reward: {'โœ“' if reward != 0 else 'โœ—'} (reward: {reward})") - - print("\n" + "=" * 60) - print("โœ… ENHANCED ORCHESTRATOR METHODS WORKING") - print("โœ… COMPREHENSIVE RL STATE: 13,400+ FEATURES") - print("โœ… ENHANCED PIVOT REWARDS: FUNCTIONAL") - print("โœ… DASHBOARD CAN NOW USE ENHANCED FEATURES") - print("=" * 60) - - return True - - except Exception as e: - print(f"\nโŒ ERROR: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_enhanced_orchestrator_bypass_cob() - if success: - print("\n๐ŸŽ‰ PIPELINE FIXES VERIFIED - READY FOR REAL-TIME TRAINING!") - else: - print("\n๐Ÿ’ฅ PIPELINE FIXES NEED MORE WORK") \ No newline at end of file diff --git a/tests/test_enhanced_order_flow_integration.py b/tests/test_enhanced_order_flow_integration.py deleted file mode 100644 index 2d197b2..0000000 --- a/tests/test_enhanced_order_flow_integration.py +++ /dev/null @@ -1,318 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Order Flow Integration - -Tests the enhanced order flow analysis capabilities including: -- Aggressive vs passive participant ratios -- Institutional vs retail trade detection -- Market maker vs taker flow analysis -- Order flow intensity measurements -- Liquidity consumption and price impact analysis -- Block trade and iceberg order detection -- High-frequency trading activity detection - -Usage: - python test_enhanced_order_flow_integration.py -""" - -import asyncio -import logging -import time -import json -from datetime import datetime, timedelta -from core.bookmap_integration import BookmapIntegration - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler('enhanced_order_flow_test.log') - ] -) -logger = logging.getLogger(__name__) - -class EnhancedOrderFlowTester: - """Test enhanced order flow analysis features""" - - def __init__(self): - self.bookmap = None - self.symbols = ['ETHUSDT', 'BTCUSDT'] - self.test_duration = 300 # 5 minutes - self.metrics_history = [] - - async def setup_integration(self): - """Initialize the Bookmap integration""" - try: - logger.info("Setting up Enhanced Order Flow Integration...") - self.bookmap = BookmapIntegration(symbols=self.symbols) - - # Add callbacks for testing - self.bookmap.add_cnn_callback(self._cnn_callback) - self.bookmap.add_dqn_callback(self._dqn_callback) - - logger.info(f"Integration setup complete for symbols: {self.symbols}") - return True - - except Exception as e: - logger.error(f"Failed to setup integration: {e}") - return False - - def _cnn_callback(self, symbol: str, features: dict): - """CNN callback for testing""" - logger.debug(f"CNN features received for {symbol}: {len(features.get('features', []))} dimensions") - - def _dqn_callback(self, symbol: str, state: dict): - """DQN callback for testing""" - logger.debug(f"DQN state received for {symbol}: {len(state.get('state', []))} dimensions") - - async def start_streaming(self): - """Start real-time data streaming""" - try: - logger.info("Starting enhanced order flow streaming...") - await self.bookmap.start_streaming() - logger.info("Streaming started successfully") - return True - - except Exception as e: - logger.error(f"Failed to start streaming: {e}") - return False - - async def monitor_order_flow(self): - """Monitor and analyze order flow for test duration""" - logger.info(f"Monitoring enhanced order flow for {self.test_duration} seconds...") - - start_time = time.time() - iteration = 0 - - while time.time() - start_time < self.test_duration: - try: - iteration += 1 - - # Test each symbol - for symbol in self.symbols: - await self._analyze_symbol_flow(symbol, iteration) - - # Wait 10 seconds between analyses - await asyncio.sleep(10) - - except Exception as e: - logger.error(f"Error during monitoring iteration {iteration}: {e}") - await asyncio.sleep(5) - - logger.info("Order flow monitoring completed") - - async def _analyze_symbol_flow(self, symbol: str, iteration: int): - """Analyze order flow for a specific symbol""" - try: - # Get enhanced order flow metrics - flow_metrics = self.bookmap.get_enhanced_order_flow_metrics(symbol) - if not flow_metrics: - logger.warning(f"No flow metrics available for {symbol}") - return - - # Log key metrics - aggressive_passive = flow_metrics['aggressive_passive'] - institutional_retail = flow_metrics['institutional_retail'] - flow_intensity = flow_metrics['flow_intensity'] - price_impact = flow_metrics['price_impact'] - maker_taker = flow_metrics['maker_taker_flow'] - - logger.info(f"\n=== {symbol} Order Flow Analysis (Iteration {iteration}) ===") - logger.info(f"Aggressive Ratio: {aggressive_passive['aggressive_ratio']:.2%}") - logger.info(f"Passive Ratio: {aggressive_passive['passive_ratio']:.2%}") - logger.info(f"Institutional Ratio: {institutional_retail['institutional_ratio']:.2%}") - logger.info(f"Retail Ratio: {institutional_retail['retail_ratio']:.2%}") - logger.info(f"Flow Intensity: {flow_intensity['current_intensity']:.2f} ({flow_intensity['intensity_category']})") - logger.info(f"Price Impact: {price_impact['avg_impact']:.2f} bps ({price_impact['impact_category']})") - logger.info(f"Buy Pressure: {maker_taker['buy_pressure']:.2%}") - logger.info(f"Sell Pressure: {maker_taker['sell_pressure']:.2%}") - - # Trade size analysis - size_dist = flow_metrics['size_distribution'] - total_trades = sum(size_dist.values()) - if total_trades > 0: - logger.info(f"Trade Size Distribution (last 100 trades):") - logger.info(f" Micro (<$1K): {size_dist.get('micro', 0)} ({size_dist.get('micro', 0)/total_trades:.1%})") - logger.info(f" Small ($1K-$10K): {size_dist.get('small', 0)} ({size_dist.get('small', 0)/total_trades:.1%})") - logger.info(f" Medium ($10K-$50K): {size_dist.get('medium', 0)} ({size_dist.get('medium', 0)/total_trades:.1%})") - logger.info(f" Large ($50K-$100K): {size_dist.get('large', 0)} ({size_dist.get('large', 0)/total_trades:.1%})") - logger.info(f" Block (>$100K): {size_dist.get('block', 0)} ({size_dist.get('block', 0)/total_trades:.1%})") - - # Volume analysis - if 'volume_stats' in flow_metrics and flow_metrics['volume_stats']: - volume_stats = flow_metrics['volume_stats'] - logger.info(f"24h Volume: {volume_stats.get('volume_24h', 0):,.0f}") - logger.info(f"24h Quote Volume: ${volume_stats.get('quote_volume_24h', 0):,.0f}") - - # Store metrics for analysis - self.metrics_history.append({ - 'timestamp': datetime.now(), - 'symbol': symbol, - 'iteration': iteration, - 'metrics': flow_metrics - }) - - # Test CNN and DQN features - await self._test_model_features(symbol) - - except Exception as e: - logger.error(f"Error analyzing flow for {symbol}: {e}") - - async def _test_model_features(self, symbol: str): - """Test CNN and DQN feature extraction""" - try: - # Test CNN features - cnn_features = self.bookmap.get_cnn_features(symbol) - if cnn_features is not None: - logger.info(f"CNN Features: {len(cnn_features)} dimensions") - logger.info(f" Order book features: {cnn_features[:80].mean():.4f} (avg)") - logger.info(f" Liquidity metrics: {cnn_features[80:90].mean():.4f} (avg)") - logger.info(f" Imbalance features: {cnn_features[90:95].mean():.4f} (avg)") - logger.info(f" Enhanced flow features: {cnn_features[95:].mean():.4f} (avg)") - - # Test DQN features - dqn_features = self.bookmap.get_dqn_state_features(symbol) - if dqn_features is not None: - logger.info(f"DQN State: {len(dqn_features)} dimensions") - logger.info(f" Order book state: {dqn_features[:20].mean():.4f} (avg)") - logger.info(f" Market indicators: {dqn_features[20:30].mean():.4f} (avg)") - logger.info(f" Enhanced flow state: {dqn_features[30:].mean():.4f} (avg)") - - # Test dashboard data - dashboard_data = self.bookmap.get_dashboard_data(symbol) - if dashboard_data and 'enhanced_order_flow' in dashboard_data: - logger.info("Dashboard data includes enhanced order flow metrics") - - except Exception as e: - logger.error(f"Error testing model features for {symbol}: {e}") - - async def stop_streaming(self): - """Stop data streaming""" - try: - logger.info("Stopping order flow streaming...") - await self.bookmap.stop_streaming() - logger.info("Streaming stopped") - - except Exception as e: - logger.error(f"Error stopping streaming: {e}") - - def generate_summary_report(self): - """Generate a summary report of the test""" - try: - logger.info("\n" + "="*60) - logger.info("ENHANCED ORDER FLOW ANALYSIS SUMMARY") - logger.info("="*60) - - if not self.metrics_history: - logger.warning("No metrics data collected during test") - return - - # Group by symbol - symbol_data = {} - for entry in self.metrics_history: - symbol = entry['symbol'] - if symbol not in symbol_data: - symbol_data[symbol] = [] - symbol_data[symbol].append(entry) - - # Analyze each symbol - for symbol, data in symbol_data.items(): - logger.info(f"\n--- {symbol} Analysis ---") - logger.info(f"Data points collected: {len(data)}") - - if len(data) > 0: - # Calculate averages - avg_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in data) / len(data) - avg_institutional = sum(d['metrics']['institutional_retail']['institutional_ratio'] for d in data) / len(data) - avg_intensity = sum(d['metrics']['flow_intensity']['current_intensity'] for d in data) / len(data) - avg_impact = sum(d['metrics']['price_impact']['avg_impact'] for d in data) / len(data) - - logger.info(f"Average Aggressive Ratio: {avg_aggressive:.2%}") - logger.info(f"Average Institutional Ratio: {avg_institutional:.2%}") - logger.info(f"Average Flow Intensity: {avg_intensity:.2f}") - logger.info(f"Average Price Impact: {avg_impact:.2f} bps") - - # Detect trends - first_half = data[:len(data)//2] if len(data) > 1 else data - second_half = data[len(data)//2:] if len(data) > 1 else data - - if len(first_half) > 0 and len(second_half) > 0: - first_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in first_half) / len(first_half) - second_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in second_half) / len(second_half) - - trend = "increasing" if second_aggressive > first_aggressive else "decreasing" - logger.info(f"Aggressive trading trend: {trend}") - - logger.info("\n" + "="*60) - logger.info("Test completed successfully!") - logger.info("Enhanced order flow analysis is working correctly.") - logger.info("="*60) - - except Exception as e: - logger.error(f"Error generating summary report: {e}") - -async def run_enhanced_order_flow_test(): - """Run the complete enhanced order flow test""" - tester = EnhancedOrderFlowTester() - - try: - # Setup - logger.info("Starting Enhanced Order Flow Integration Test") - logger.info("This test will demonstrate:") - logger.info("- Aggressive vs Passive participant analysis") - logger.info("- Institutional vs Retail trade detection") - logger.info("- Order flow intensity measurements") - logger.info("- Price impact and liquidity consumption analysis") - logger.info("- Block trade and iceberg order detection") - logger.info("- Enhanced CNN and DQN feature extraction") - - if not await tester.setup_integration(): - logger.error("Failed to setup integration") - return False - - # Start streaming - if not await tester.start_streaming(): - logger.error("Failed to start streaming") - return False - - # Wait for initial data - logger.info("Waiting 30 seconds for initial data...") - await asyncio.sleep(30) - - # Monitor order flow - await tester.monitor_order_flow() - - # Generate report - tester.generate_summary_report() - - return True - - except Exception as e: - logger.error(f"Test failed: {e}") - return False - - finally: - # Cleanup - try: - await tester.stop_streaming() - except Exception as e: - logger.error(f"Error during cleanup: {e}") - -if __name__ == "__main__": - try: - # Run the test - success = asyncio.run(run_enhanced_order_flow_test()) - - if success: - print("\nโœ… Enhanced Order Flow Integration Test PASSED") - print("All enhanced order flow analysis features are working correctly!") - else: - print("\nโŒ Enhanced Order Flow Integration Test FAILED") - print("Check the logs for details.") - - except KeyboardInterrupt: - print("\nโš ๏ธ Test interrupted by user") - except Exception as e: - print(f"\n๐Ÿ’ฅ Test crashed: {e}") \ No newline at end of file diff --git a/tests/test_enhanced_pivot_rl_system.py b/tests/test_enhanced_pivot_rl_system.py deleted file mode 100644 index c435dd7..0000000 --- a/tests/test_enhanced_pivot_rl_system.py +++ /dev/null @@ -1,320 +0,0 @@ -""" -Test Enhanced Pivot-Based RL System - -Tests the new system with: -- Different thresholds for entry vs exit -- Pivot-based rewards -- CNN predictions for early pivot detection -- Uninvested rewards -""" - -import logging -import sys -import numpy as np -import pandas as pd -from datetime import datetime, timedelta -from typing import Dict, Any - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', - stream=sys.stdout -) - -logger = logging.getLogger(__name__) - -# Add project root to Python path -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from training.enhanced_pivot_rl_trainer import EnhancedPivotRLTrainer, create_enhanced_pivot_trainer - -def test_enhanced_pivot_thresholds(): - """Test the enhanced pivot-based threshold system""" - logger.info("=== Testing Enhanced Pivot-Based Thresholds ===") - - try: - # Create components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True - ) - - # Test threshold initialization - thresholds = orchestrator.pivot_rl_trainer.get_current_thresholds() - logger.info(f"Initial thresholds:") - logger.info(f" Entry: {thresholds['entry_threshold']:.3f}") - logger.info(f" Exit: {thresholds['exit_threshold']:.3f}") - logger.info(f" Uninvested: {thresholds['uninvested_threshold']:.3f}") - - # Verify entry threshold is higher than exit threshold - assert thresholds['entry_threshold'] > thresholds['exit_threshold'], "Entry threshold should be higher than exit" - logger.info("โœ… Entry threshold correctly higher than exit threshold") - - return True - - except Exception as e: - logger.error(f"Error testing thresholds: {e}") - return False - -def test_pivot_reward_calculation(): - """Test the pivot-based reward calculation""" - logger.info("=== Testing Pivot-Based Reward Calculation ===") - - try: - # Create enhanced pivot trainer - data_provider = DataProvider() - pivot_trainer = create_enhanced_pivot_trainer(data_provider) - - # Create mock trade decision and outcome - trade_decision = { - 'action': 'BUY', - 'confidence': 0.75, - 'price': 2500.0, - 'timestamp': datetime.now() - } - - trade_outcome = { - 'net_pnl': 15.50, # Profitable trade - 'exit_price': 2515.0, - 'duration': timedelta(minutes=45) - } - - # Create mock market data - market_data = pd.DataFrame({ - 'open': np.random.normal(2500, 10, 100), - 'high': np.random.normal(2510, 10, 100), - 'low': np.random.normal(2490, 10, 100), - 'close': np.random.normal(2500, 10, 100), - 'volume': np.random.normal(1000, 100, 100) - }) - market_data.index = pd.date_range(start=datetime.now() - timedelta(hours=2), periods=100, freq='1min') - - # Calculate reward - reward = pivot_trainer.calculate_pivot_based_reward( - trade_decision, market_data, trade_outcome - ) - - logger.info(f"Calculated pivot-based reward: {reward:.3f}") - - # Test should return a reasonable reward for profitable trade - assert -15.0 <= reward <= 10.0, f"Reward {reward} outside expected range" - logger.info("โœ… Pivot-based reward calculation working") - - # Test uninvested reward - low_conf_decision = { - 'action': 'HOLD', - 'confidence': 0.35, # Below uninvested threshold - 'price': 2500.0, - 'timestamp': datetime.now() - } - - uninvested_reward = pivot_trainer._calculate_uninvested_rewards(low_conf_decision, 0.35) - logger.info(f"Uninvested reward for low confidence: {uninvested_reward:.3f}") - - assert uninvested_reward > 0, "Should get positive reward for staying uninvested with low confidence" - logger.info("โœ… Uninvested rewards working correctly") - - return True - - except Exception as e: - logger.error(f"Error testing pivot rewards: {e}") - return False - -def test_confidence_adjustment(): - """Test confidence-based reward adjustments""" - logger.info("=== Testing Confidence-Based Adjustments ===") - - try: - pivot_trainer = create_enhanced_pivot_trainer() - - # Test overconfidence penalty on loss - high_conf_loss = { - 'action': 'BUY', - 'confidence': 0.85, # High confidence - 'price': 2500.0, - 'timestamp': datetime.now() - } - - loss_outcome = { - 'net_pnl': -25.0, # Loss - 'exit_price': 2475.0, - 'duration': timedelta(hours=3) - } - - confidence_adjustment = pivot_trainer._calculate_confidence_adjustment( - high_conf_loss, loss_outcome - ) - - logger.info(f"Confidence adjustment for overconfident loss: {confidence_adjustment:.3f}") - assert confidence_adjustment < 0, "Should penalize overconfidence on losses" - - # Test underconfidence penalty on win - low_conf_win = { - 'action': 'BUY', - 'confidence': 0.35, # Low confidence - 'price': 2500.0, - 'timestamp': datetime.now() - } - - win_outcome = { - 'net_pnl': 20.0, # Profit - 'exit_price': 2520.0, - 'duration': timedelta(minutes=30) - } - - confidence_adjustment_2 = pivot_trainer._calculate_confidence_adjustment( - low_conf_win, win_outcome - ) - - logger.info(f"Confidence adjustment for underconfident win: {confidence_adjustment_2:.3f}") - # Should be small penalty or zero - - logger.info("โœ… Confidence adjustments working correctly") - return True - - except Exception as e: - logger.error(f"Error testing confidence adjustments: {e}") - return False - -def test_dynamic_threshold_updates(): - """Test dynamic threshold updating based on performance""" - logger.info("=== Testing Dynamic Threshold Updates ===") - - try: - pivot_trainer = create_enhanced_pivot_trainer() - - # Get initial thresholds - initial_thresholds = pivot_trainer.get_current_thresholds() - logger.info(f"Initial thresholds: {initial_thresholds}") - - # Simulate some poor performance (low win rate) - for i in range(25): - outcome = { - 'timestamp': datetime.now(), - 'action': 'BUY', - 'confidence': 0.6, - 'net_pnl': -5.0 if i < 20 else 10.0, # 20% win rate - 'reward': -1.0 if i < 20 else 2.0, - 'duration': timedelta(hours=2) - } - pivot_trainer.trade_outcomes.append(outcome) - - # Update thresholds - pivot_trainer.update_thresholds_based_on_performance() - - # Get updated thresholds - updated_thresholds = pivot_trainer.get_current_thresholds() - logger.info(f"Updated thresholds after poor performance: {updated_thresholds}") - - # Entry threshold should increase (more selective) after poor performance - assert updated_thresholds['entry_threshold'] >= initial_thresholds['entry_threshold'], \ - "Entry threshold should increase after poor performance" - - logger.info("โœ… Dynamic threshold updates working correctly") - return True - - except Exception as e: - logger.error(f"Error testing dynamic thresholds: {e}") - return False - -def test_cnn_integration(): - """Test CNN integration for pivot predictions""" - logger.info("=== Testing CNN Integration ===") - - try: - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True - ) - - # Check if Williams structure is initialized with CNN - williams = orchestrator.pivot_rl_trainer.williams - logger.info(f"Williams CNN enabled: {williams.enable_cnn_feature}") - logger.info(f"Williams CNN model available: {williams.cnn_model is not None}") - - # Test CNN threshold adjustment - from core.enhanced_orchestrator import MarketState - from datetime import datetime - - mock_market_state = MarketState( - symbol='ETH/USDT', - timestamp=datetime.now(), - prices={'1s': 2500.0}, - features={'1s': np.array([])}, - volatility=0.02, - volume=1000.0, - trend_strength=0.5, - market_regime='normal', - universal_data=None - ) - - cnn_adjustment = orchestrator._get_cnn_threshold_adjustment( - 'ETH/USDT', 'BUY', mock_market_state - ) - - logger.info(f"CNN threshold adjustment: {cnn_adjustment:.3f}") - assert 0.0 <= cnn_adjustment <= 0.1, "CNN adjustment should be reasonable" - - logger.info("โœ… CNN integration working correctly") - return True - - except Exception as e: - logger.error(f"Error testing CNN integration: {e}") - return False - -def run_all_tests(): - """Run all enhanced pivot RL system tests""" - logger.info("๐Ÿš€ Starting Enhanced Pivot RL System Tests") - - tests = [ - test_enhanced_pivot_thresholds, - test_pivot_reward_calculation, - test_confidence_adjustment, - test_dynamic_threshold_updates, - test_cnn_integration - ] - - passed = 0 - total = len(tests) - - for test_func in tests: - try: - if test_func(): - passed += 1 - logger.info(f"โœ… {test_func.__name__} PASSED") - else: - logger.error(f"โŒ {test_func.__name__} FAILED") - except Exception as e: - logger.error(f"โŒ {test_func.__name__} ERROR: {e}") - - logger.info(f"\n๐Ÿ“Š Test Results: {passed}/{total} tests passed") - - if passed == total: - logger.info("๐ŸŽ‰ All Enhanced Pivot RL System tests PASSED!") - return True - else: - logger.error(f"โš ๏ธ {total - passed} tests FAILED") - return False - -if __name__ == "__main__": - success = run_all_tests() - - if success: - logger.info("\n๐Ÿ”ฅ Enhanced Pivot RL System is ready for deployment!") - logger.info("Key improvements:") - logger.info(" โœ… Higher entry threshold than exit threshold") - logger.info(" โœ… Pivot-based reward calculation") - logger.info(" โœ… CNN predictions for early pivot detection") - logger.info(" โœ… Rewards for staying uninvested when uncertain") - logger.info(" โœ… Confidence-based reward adjustments") - logger.info(" โœ… Dynamic threshold learning from performance") - else: - logger.error("\nโŒ Enhanced Pivot RL System has issues that need fixing") - - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_enhanced_rl_fix.py b/tests/test_enhanced_rl_fix.py deleted file mode 100644 index 64b12b5..0000000 --- a/tests/test_enhanced_rl_fix.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced RL Fix - Verify comprehensive state building and reward calculation -""" - -import sys -from pathlib import Path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -def test_enhanced_orchestrator(): - """Test enhanced orchestrator methods""" - print("=== TESTING ENHANCED RL FIXES ===") - - try: - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - from core.data_provider import DataProvider - print("โœ“ Enhanced orchestrator imported successfully") - - # Create orchestrator with enhanced RL enabled - dp = DataProvider() - eo = EnhancedTradingOrchestrator( - data_provider=dp, - enhanced_rl_training=True, - symbols=['ETH/USDT', 'BTC/USDT'] - ) - print("โœ“ Enhanced orchestrator created") - - # Test method availability - methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward', '_get_symbol_correlation'] - print("\nMethod availability:") - for method in methods: - available = hasattr(eo, method) - print(f" {method}: {'โœ“' if available else 'โœ—'}") - - # Test comprehensive state building - print("\nTesting comprehensive state building...") - state = eo.build_comprehensive_rl_state('ETH/USDT') - if state is not None: - print(f"โœ“ Comprehensive state built: {len(state)} features") - print(f" State type: {type(state)}") - print(f" State shape: {state.shape if hasattr(state, 'shape') else 'No shape'}") - else: - print("โœ— Comprehensive state returned None") - - # Debug why state is None - print("\nDEBUGGING STATE BUILDING...") - print(f" Williams enabled: {hasattr(eo, 'williams_enabled')}") - print(f" COB integration active: {hasattr(eo, 'cob_integration_active')}") - print(f" Enhanced RL training: {getattr(eo, 'enhanced_rl_training', 'Not set')}") - - # Test enhanced reward calculation - print("\nTesting enhanced reward calculation...") - trade_decision = { - 'action': 'BUY', - 'confidence': 0.75, - 'price': 2500.0, - 'timestamp': '2023-01-01 00:00:00' - } - trade_outcome = { - 'net_pnl': 50.0, - 'exit_price': 2550.0, - 'duration': '00:15:00' - } - market_data = {'symbol': 'ETH/USDT'} - - try: - reward = eo.calculate_enhanced_pivot_reward(trade_decision, market_data, trade_outcome) - print(f"โœ“ Enhanced reward calculated: {reward}") - except Exception as e: - print(f"โœ— Enhanced reward failed: {e}") - import traceback - traceback.print_exc() - - print("\n=== TEST COMPLETE ===") - - except Exception as e: - print(f"โœ— Test failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_enhanced_orchestrator() \ No newline at end of file diff --git a/tests/test_enhanced_williams_cnn.py b/tests/test_enhanced_williams_cnn.py deleted file mode 100644 index 1da5152..0000000 --- a/tests/test_enhanced_williams_cnn.py +++ /dev/null @@ -1,346 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Williams Market Structure with CNN Integration - -This script demonstrates the multi-timeframe, multi-symbol CNN-enabled -Williams Market Structure that predicts pivot points using TrainingDataPacket. - -Features tested: -- Multi-timeframe data integration (1s, 1m, 1h) -- Multi-symbol support (ETH, BTC) -- Tick data aggregation -- 1h-based normalization strategy -- Multi-level pivot prediction (5 levels, type + price) -""" - -import numpy as np -import logging -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Any -from dataclasses import dataclass - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Mock TrainingDataPacket for testing -@dataclass -class MockTrainingDataPacket: - """Mock TrainingDataPacket for testing CNN integration""" - timestamp: datetime - symbol: str - tick_cache: List[Dict[str, Any]] - one_second_bars: List[Dict[str, Any]] - multi_timeframe_data: Dict[str, List[Dict[str, Any]]] - cnn_features: Optional[Dict[str, np.ndarray]] = None - cnn_predictions: Optional[Dict[str, np.ndarray]] = None - market_state: Optional[Any] = None - universal_stream: Optional[Any] = None - -class MockTrainingDataProvider: - """Mock provider that supplies TrainingDataPacket for testing""" - - def __init__(self): - self.training_data_buffer = [] - self._generate_mock_data() - - def _generate_mock_data(self): - """Generate comprehensive mock market data""" - current_time = datetime.now() - - # Generate ETH data for different timeframes - eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time) # 15 min of 1s data - eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time) # 15 hours of 1m data - eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time) # 24 hours of 1h data - - # Generate BTC data - btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time) # 5 min of 1s data - - # Generate tick data - tick_data = self._generate_tick_data(current_time) - - # Create comprehensive TrainingDataPacket - training_packet = MockTrainingDataPacket( - timestamp=current_time, - symbol='ETH/USDT', - tick_cache=tick_data, - one_second_bars=eth_1s_data[-300:], # Last 5 minutes - multi_timeframe_data={ - 'ETH/USDT': { - '1s': eth_1s_data, - '1m': eth_1m_data, - '1h': eth_1h_data - }, - 'BTC/USDT': { - '1s': btc_1s_data - } - } - ) - - self.training_data_buffer.append(training_packet) - logger.info(f"Generated mock training data: {len(eth_1s_data)} 1s bars, {len(eth_1m_data)} 1m bars, {len(eth_1h_data)} 1h bars") - logger.info(f"ETH 1h price range: {min(bar['low'] for bar in eth_1h_data):.2f} - {max(bar['high'] for bar in eth_1h_data):.2f}") - - def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str, end_time: datetime) -> List[Dict[str, Any]]: - """Generate realistic OHLCV data with indicators""" - data = [] - - # Calculate time delta based on timeframe - if timeframe == '1s': - delta = timedelta(seconds=1) - elif timeframe == '1m': - delta = timedelta(minutes=1) - elif timeframe == '1h': - delta = timedelta(hours=1) - else: - delta = timedelta(minutes=1) - - current_price = base_price - - for i in range(count): - timestamp = end_time - delta * (count - i - 1) - - # Generate realistic price movement - price_change = np.random.normal(0, base_price * 0.001) # 0.1% volatility - current_price = max(current_price + price_change, base_price * 0.8) # Floor at 80% of base - - # Generate OHLCV - open_price = current_price - high_price = open_price * (1 + abs(np.random.normal(0, 0.002))) - low_price = open_price * (1 - abs(np.random.normal(0, 0.002))) - close_price = low_price + (high_price - low_price) * np.random.random() - volume = np.random.exponential(1000) - - current_price = close_price - - # Calculate basic indicators (placeholders) - sma_20 = close_price * (1 + np.random.normal(0, 0.001)) - ema_20 = close_price * (1 + np.random.normal(0, 0.0005)) - rsi_14 = 30 + np.random.random() * 40 # RSI between 30-70 - macd = np.random.normal(0, 0.1) - bb_upper = high_price * 1.02 - - bar = { - 'timestamp': timestamp, - 'open': open_price, - 'high': high_price, - 'low': low_price, - 'close': close_price, - 'volume': volume, - 'sma_20': sma_20, - 'ema_20': ema_20, - 'rsi_14': rsi_14, - 'macd': macd, - 'bb_upper': bb_upper - } - data.append(bar) - - return data - - def _generate_tick_data(self, end_time: datetime) -> List[Dict[str, Any]]: - """Generate realistic tick data for last 5 minutes""" - ticks = [] - - # Generate ETH ticks - for i in range(300): # 5 minutes * 60 seconds - tick_time = end_time - timedelta(seconds=300 - i) - - # 2-5 ticks per second - ticks_per_second = np.random.randint(2, 6) - - for j in range(ticks_per_second): - tick = { - 'symbol': 'ETH/USDT', - 'timestamp': tick_time + timedelta(milliseconds=j * 200), - 'price': 2400.0 + np.random.normal(0, 5), - 'volume': np.random.exponential(0.5), - 'quantity': np.random.exponential(1.0), - 'side': 'buy' if np.random.random() > 0.5 else 'sell' - } - ticks.append(tick) - - # Generate BTC ticks - for i in range(300): # 5 minutes * 60 seconds - tick_time = end_time - timedelta(seconds=300 - i) - - ticks_per_second = np.random.randint(1, 4) - - for j in range(ticks_per_second): - tick = { - 'symbol': 'BTC/USDT', - 'timestamp': tick_time + timedelta(milliseconds=j * 300), - 'price': 45000.0 + np.random.normal(0, 100), - 'volume': np.random.exponential(0.1), - 'quantity': np.random.exponential(0.5), - 'side': 'buy' if np.random.random() > 0.5 else 'sell' - } - ticks.append(tick) - - return ticks - - def get_latest_training_data(self): - """Return the latest TrainingDataPacket""" - return self.training_data_buffer[-1] if self.training_data_buffer else None - - -def test_enhanced_williams_cnn(): - """Test the enhanced Williams Market Structure with CNN integration""" - try: - from training.williams_market_structure import WilliamsMarketStructure, SwingType - - logger.info("=" * 80) - logger.info("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION") - logger.info("=" * 80) - - # Create mock data provider - data_provider = MockTrainingDataProvider() - - # Initialize Williams Market Structure with CNN - williams = WilliamsMarketStructure( - swing_strengths=[2, 3, 5], # Reduced for testing - cnn_input_shape=(900, 50), # 900 timesteps, 50 features - cnn_output_size=10, # 5 levels * 2 outputs (type + price) - enable_cnn_feature=True, # Enable CNN features - training_data_provider=data_provider - ) - - logger.info(f"CNN enabled: {williams.enable_cnn_feature}") - logger.info(f"Training data provider available: {williams.training_data_provider is not None}") - - # Generate test OHLCV data for Williams calculation - test_ohlcv = generate_test_ohlcv_data() - logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars") - - # Test Williams calculation with CNN integration - logger.info("\n" + "=" * 60) - logger.info("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION") - logger.info("=" * 60) - - structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv) - - # Display results - logger.info(f"\nWilliams Structure Analysis Results:") - logger.info(f"Calculated levels: {len(structure_levels)}") - - for level_key, level_data in structure_levels.items(): - swing_count = len(level_data.swing_points) - logger.info(f"{level_key}: {swing_count} swing points, " - f"trend: {level_data.trend_analysis.direction.value}, " - f"bias: {level_data.current_bias.value}") - - if swing_count > 0: - latest_swing = level_data.swing_points[-1] - logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}") - - # Test CNN input preparation directly - logger.info("\n" + "=" * 60) - logger.info("TESTING CNN INPUT PREPARATION") - logger.info("=" * 60) - - if williams.enable_cnn_feature and structure_levels['level_0'].swing_points: - test_pivot = structure_levels['level_0'].swing_points[-1] - - logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}") - - # Test input preparation - cnn_input = williams._prepare_cnn_input( - current_pivot=test_pivot, - ohlcv_data_context=test_ohlcv, - previous_pivot_details=None - ) - - logger.info(f"CNN input shape: {cnn_input.shape}") - logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]") - logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}") - - # Test ground truth preparation - if len(structure_levels['level_0'].swing_points) >= 2: - prev_pivot = structure_levels['level_0'].swing_points[-2] - current_pivot = structure_levels['level_0'].swing_points[-1] - - prev_details = {'pivot': prev_pivot} - ground_truth = williams._get_cnn_ground_truth(prev_details, current_pivot) - - logger.info(f"Ground truth shape: {ground_truth.shape}") - logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}") - logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}") - - # Test normalization - logger.info("\n" + "=" * 60) - logger.info("TESTING 1H-BASED NORMALIZATION") - logger.info("=" * 60) - - training_packet = data_provider.get_latest_training_data() - if training_packet: - # Test normalization with sample data - sample_features = np.random.normal(2400, 50, (100, 10)) # ETH-like prices - - normalized = williams._normalize_features_by_1h_range(sample_features, training_packet) - - logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]") - logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]") - - # Check if 1h data is being used for normalization - eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', []) - if eth_1h: - h1_prices = [] - for bar in eth_1h[-24:]: - h1_prices.extend([bar['open'], bar['high'], bar['low'], bar['close']]) - h1_range = max(h1_prices) - min(h1_prices) - logger.info(f"1h price range used for normalization: {h1_range:.2f}") - - logger.info("\n" + "=" * 80) - logger.info("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY") - logger.info("=" * 80) - - return True - - except ImportError as e: - logger.error(f"Import error - some dependencies missing: {e}") - logger.info("This is expected if TensorFlow or other dependencies are not installed") - return False - except Exception as e: - logger.error(f"Test failed with error: {e}", exc_info=True) - return False - - -def generate_test_ohlcv_data(bars=200, base_price=2400.0): - """Generate test OHLCV data for Williams calculation""" - data = [] - current_price = base_price - current_time = datetime.now() - - for i in range(bars): - timestamp = current_time - timedelta(seconds=bars - i) - - # Generate price movement - price_change = np.random.normal(0, base_price * 0.002) - current_price = max(current_price + price_change, base_price * 0.9) - - open_price = current_price - high_price = open_price * (1 + abs(np.random.normal(0, 0.003))) - low_price = open_price * (1 - abs(np.random.normal(0, 0.003))) - close_price = low_price + (high_price - low_price) * np.random.random() - volume = np.random.exponential(1000) - - current_price = close_price - - bar = [ - timestamp.timestamp(), - open_price, - high_price, - low_price, - close_price, - volume - ] - data.append(bar) - - return np.array(data) - - -if __name__ == "__main__": - success = test_enhanced_williams_cnn() - if success: - print("\nโœ… All tests passed! Enhanced Williams CNN integration is working.") - else: - print("\nโŒ Some tests failed. Check logs for details.") \ No newline at end of file diff --git a/tests/test_essential.py b/tests/test_essential.py deleted file mode 100644 index 05222be..0000000 --- a/tests/test_essential.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -""" -Essential Test Suite - Core functionality tests - -This file contains the most important tests to verify core functionality: -- Data loading and processing -- Basic model operations -- Trading signal generation -- Critical utility functions -""" - -import sys -import os -import unittest -import logging -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -logger = logging.getLogger(__name__) - -class TestEssentialFunctionality(unittest.TestCase): - """Essential tests for core trading system functionality""" - - def test_imports(self): - """Test that all critical modules can be imported""" - try: - from core.config import get_config - from core.data_provider import DataProvider - from utils.model_utils import robust_save, robust_load - logger.info("โœ… All critical imports successful") - except ImportError as e: - self.fail(f"Critical import failed: {e}") - - def test_config_loading(self): - """Test configuration loading""" - try: - from core.config import get_config - config = get_config() - self.assertIsNotNone(config, "Config should be loaded") - logger.info("โœ… Configuration loading successful") - except Exception as e: - self.fail(f"Config loading failed: {e}") - - def test_data_provider_initialization(self): - """Test DataProvider can be initialized""" - try: - from core.data_provider import DataProvider - data_provider = DataProvider(['ETH/USDT'], ['1m']) - self.assertIsNotNone(data_provider, "DataProvider should initialize") - logger.info("โœ… DataProvider initialization successful") - except Exception as e: - self.fail(f"DataProvider initialization failed: {e}") - - def test_model_utils(self): - """Test model utility functions""" - try: - from utils.model_utils import get_model_info - import tempfile - - # Test with non-existent file - info = get_model_info("non_existent_file.pt") - self.assertFalse(info['exists'], "Should report file doesn't exist") - - logger.info("โœ… Model utils test successful") - except Exception as e: - self.fail(f"Model utils test failed: {e}") - - def test_signal_generation_logic(self): - """Test basic signal generation logic""" - import numpy as np - - # Test signal distribution calculation - predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY - - buy_count = np.sum(predictions == 2) - sell_count = np.sum(predictions == 0) - hold_count = np.sum(predictions == 1) - total = len(predictions) - - distribution = { - "BUY": buy_count / total, - "SELL": sell_count / total, - "HOLD": hold_count / total - } - - # Verify calculations - self.assertAlmostEqual(distribution["BUY"], 0.3, places=1) - self.assertAlmostEqual(distribution["SELL"], 0.3, places=1) - self.assertAlmostEqual(distribution["HOLD"], 0.4, places=1) - self.assertAlmostEqual(sum(distribution.values()), 1.0, places=1) - - logger.info("โœ… Signal generation logic test successful") - -def run_essential_tests(): - """Run essential tests only""" - suite = unittest.TestLoader().loadTestsFromTestCase(TestEssentialFunctionality) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(suite) - return result.wasSuccessful() - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - logger.info("Running essential functionality tests...") - - success = run_essential_tests() - - if success: - logger.info("โœ… All essential tests passed!") - sys.exit(0) - else: - logger.error("โŒ Essential tests failed!") - sys.exit(1) \ No newline at end of file diff --git a/tests/test_extrema_training_enhanced.py b/tests/test_extrema_training_enhanced.py deleted file mode 100644 index 69c7dc8..0000000 --- a/tests/test_extrema_training_enhanced.py +++ /dev/null @@ -1,508 +0,0 @@ -#!/usr/bin/env python3 -""" -Enhanced Extrema Training Test Suite - -Tests the complete extrema training system including: -1. 200-candle 1m context data loading -2. Local extrema detection (bottoms and tops) -3. Training on not-so-perfect opportunities -4. Dashboard integration with extrema information -5. Reusable functionality across different dashboards - -This test suite verifies all components work together correctly. -""" - -import sys -import os -import asyncio -import logging -import numpy as np -import pandas as pd -from datetime import datetime, timedelta -from typing import Dict, List, Any -import time - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_extrema_trainer_initialization(): - """Test 1: Extrema trainer initialization and basic functionality""" - print("\n" + "="*60) - print("TEST 1: Extrema Trainer Initialization") - print("="*60) - - try: - from core.extrema_trainer import ExtremaTrainer - from core.data_provider import DataProvider - - # Initialize components - data_provider = DataProvider() - symbols = ['ETHUSDT', 'BTCUSDT'] - - # Create extrema trainer - extrema_trainer = ExtremaTrainer( - data_provider=data_provider, - symbols=symbols, - window_size=10 - ) - - # Verify initialization - assert extrema_trainer.symbols == symbols - assert extrema_trainer.window_size == 10 - assert len(extrema_trainer.detected_extrema) == len(symbols) - assert len(extrema_trainer.context_data) == len(symbols) - - print("โœ… Extrema trainer initialized successfully") - print(f" - Symbols: {symbols}") - print(f" - Window size: {extrema_trainer.window_size}") - print(f" - Context data containers: {len(extrema_trainer.context_data)}") - print(f" - Extrema containers: {len(extrema_trainer.detected_extrema)}") - - return True, extrema_trainer - - except Exception as e: - print(f"โŒ Extrema trainer initialization failed: {e}") - return False, None - -def test_context_data_loading(extrema_trainer): - """Test 2: 200-candle 1m context data loading""" - print("\n" + "="*60) - print("TEST 2: 200-Candle 1m Context Data Loading") - print("="*60) - - try: - # Initialize context data - start_time = time.time() - results = extrema_trainer.initialize_context_data() - load_time = time.time() - start_time - - # Verify results - successful_loads = sum(1 for success in results.values() if success) - total_symbols = len(extrema_trainer.symbols) - - print(f"โœ… Context data loading completed in {load_time:.2f} seconds") - print(f" - Success rate: {successful_loads}/{total_symbols} symbols") - - # Check context data details - for symbol in extrema_trainer.symbols: - context = extrema_trainer.context_data[symbol] - candles_loaded = len(context.candles) - features_available = context.features is not None - - print(f" - {symbol}: {candles_loaded} candles, features: {'โœ…' if features_available else 'โŒ'}") - - if features_available: - print(f" Features shape: {context.features.shape}") - - # Test context feature retrieval - for symbol in extrema_trainer.symbols: - features = extrema_trainer.get_context_features_for_model(symbol) - if features is not None: - print(f" - {symbol} model features: {features.shape}") - else: - print(f" - {symbol} model features: Not available") - - return successful_loads > 0 - - except Exception as e: - print(f"โŒ Context data loading failed: {e}") - return False - -def test_extrema_detection(extrema_trainer): - """Test 3: Local extrema detection (bottoms and tops)""" - print("\n" + "="*60) - print("TEST 3: Local Extrema Detection") - print("="*60) - - try: - # Run batch extrema detection - start_time = time.time() - detection_results = extrema_trainer.run_batch_detection() - detection_time = time.time() - start_time - - # Analyze results - total_extrema = sum(len(extrema_list) for extrema_list in detection_results.values()) - - print(f"โœ… Extrema detection completed in {detection_time:.2f} seconds") - print(f" - Total extrema detected: {total_extrema}") - - # Detailed breakdown by symbol - for symbol, extrema_list in detection_results.items(): - if extrema_list: - bottoms = len([e for e in extrema_list if e.extrema_type == 'bottom']) - tops = len([e for e in extrema_list if e.extrema_type == 'top']) - avg_confidence = np.mean([e.confidence for e in extrema_list]) - - print(f" - {symbol}: {len(extrema_list)} extrema (bottoms: {bottoms}, tops: {tops})") - print(f" Average confidence: {avg_confidence:.3f}") - - # Show recent extrema details - for extrema in extrema_list[-2:]: # Last 2 extrema - print(f" {extrema.extrema_type.upper()} @ ${extrema.price:.2f} " - f"(confidence: {extrema.confidence:.3f}, action: {extrema.optimal_action})") - - # Test perfect moves for CNN - perfect_moves = extrema_trainer.get_perfect_moves_for_cnn(count=20) - print(f" - Perfect moves for CNN training: {len(perfect_moves)}") - - if perfect_moves: - for move in perfect_moves[:3]: # Show first 3 - print(f" {move['optimal_action']} {move['symbol']} @ {move['timestamp'].strftime('%H:%M:%S')} " - f"(outcome: {move['actual_outcome']:.3f}, confidence: {move['confidence_should_have_been']:.3f})") - - return total_extrema > 0 - - except Exception as e: - print(f"โŒ Extrema detection failed: {e}") - return False - -def test_context_data_updates(extrema_trainer): - """Test 4: Context data updates and continuous extrema detection""" - print("\n" + "="*60) - print("TEST 4: Context Data Updates and Continuous Detection") - print("="*60) - - try: - # Test single symbol update - symbol = extrema_trainer.symbols[0] - - print(f"Testing context update for {symbol}...") - start_time = time.time() - update_results = extrema_trainer.update_context_data(symbol) - update_time = time.time() - start_time - - print(f"โœ… Context update completed in {update_time:.2f} seconds") - print(f" - Update result for {symbol}: {'โœ…' if update_results.get(symbol, False) else 'โŒ'}") - - # Test all symbols update - print("Testing context update for all symbols...") - start_time = time.time() - all_update_results = extrema_trainer.update_context_data() - all_update_time = time.time() - start_time - - successful_updates = sum(1 for success in all_update_results.values() if success) - - print(f"โœ… All symbols update completed in {all_update_time:.2f} seconds") - print(f" - Success rate: {successful_updates}/{len(extrema_trainer.symbols)} symbols") - - # Check for new extrema after updates - new_extrema = extrema_trainer.run_batch_detection() - new_total = sum(len(extrema_list) for extrema_list in new_extrema.values()) - - print(f" - New extrema detected after update: {new_total}") - - return successful_updates > 0 - - except Exception as e: - print(f"โŒ Context data updates failed: {e}") - return False - -def test_extrema_stats_and_training_data(extrema_trainer): - """Test 5: Extrema statistics and training data retrieval""" - print("\n" + "="*60) - print("TEST 5: Extrema Statistics and Training Data") - print("="*60) - - try: - # Get comprehensive stats - stats = extrema_trainer.get_extrema_stats() - - print("โœ… Extrema statistics retrieved successfully") - print(f" - Total extrema detected: {stats.get('total_extrema_detected', 0)}") - print(f" - Training queue size: {stats.get('training_queue_size', 0)}") - print(f" - Window size: {stats.get('window_size', 0)}") - - # Confidence thresholds - thresholds = stats.get('confidence_thresholds', {}) - print(f" - Confidence thresholds: min={thresholds.get('min', 0):.2f}, max={thresholds.get('max', 0):.2f}") - - # Context data status - context_status = stats.get('context_data_status', {}) - for symbol, status in context_status.items(): - candles = status.get('candles_loaded', 0) - features = status.get('features_available', False) - last_update = status.get('last_update', 'Unknown') - print(f" - {symbol}: {candles} candles, features: {'โœ…' if features else 'โŒ'}, updated: {last_update}") - - # Recent extrema breakdown - recent_extrema = stats.get('recent_extrema', {}) - if recent_extrema: - print(f" - Recent extrema: {recent_extrema.get('bottoms', 0)} bottoms, {recent_extrema.get('tops', 0)} tops") - print(f" - Average confidence: {recent_extrema.get('avg_confidence', 0):.3f}") - print(f" - Average outcome: {recent_extrema.get('avg_outcome', 0):.3f}") - - # Test training data retrieval - training_data = extrema_trainer.get_extrema_training_data(count=10, min_confidence=0.4) - print(f" - Training data (min confidence 0.4): {len(training_data)} cases") - - if training_data: - high_confidence_cases = len([case for case in training_data if case.confidence > 0.7]) - print(f" - High confidence cases (>0.7): {high_confidence_cases}") - - return True - - except Exception as e: - print(f"โŒ Extrema statistics retrieval failed: {e}") - return False - -def test_enhanced_orchestrator_integration(): - """Test 6: Enhanced orchestrator integration with extrema trainer""" - print("\n" + "="*60) - print("TEST 6: Enhanced Orchestrator Integration") - print("="*60) - - try: - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - from core.data_provider import DataProvider - - # Initialize orchestrator (should include extrema trainer) - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Verify extrema trainer integration - assert hasattr(orchestrator, 'extrema_trainer') - assert orchestrator.extrema_trainer is not None - - print("โœ… Enhanced orchestrator initialized with extrema trainer") - print(f" - Extrema trainer symbols: {orchestrator.extrema_trainer.symbols}") - - # Test extrema stats retrieval through orchestrator - extrema_stats = orchestrator.get_extrema_stats() - print(f" - Extrema stats available: {'โœ…' if extrema_stats else 'โŒ'}") - - if extrema_stats: - print(f" - Total extrema: {extrema_stats.get('total_extrema_detected', 0)}") - print(f" - Training queue: {extrema_stats.get('training_queue_size', 0)}") - - # Test context features retrieval - for symbol in orchestrator.symbols[:2]: # Test first 2 symbols - context_features = orchestrator.get_context_features_for_model(symbol) - if context_features is not None: - print(f" - {symbol} context features: {context_features.shape}") - else: - print(f" - {symbol} context features: Not available") - - # Test perfect moves for CNN - perfect_moves = orchestrator.get_perfect_moves_for_cnn(count=10) - print(f" - Perfect moves for CNN: {len(perfect_moves)}") - - return True, orchestrator - - except Exception as e: - print(f"โŒ Enhanced orchestrator integration failed: {e}") - return False, None - -def test_dashboard_integration(orchestrator): - """Test 7: Dashboard integration with extrema information""" - print("\n" + "="*60) - print("TEST 7: Dashboard Integration") - print("="*60) - - try: - from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard - - # Initialize dashboard with enhanced orchestrator - dashboard = RealTimeScalpingDashboard(orchestrator=orchestrator) - - print("โœ… Dashboard initialized with enhanced orchestrator") - - # Test sensitivity learning info (should include extrema stats) - sensitivity_info = dashboard._get_sensitivity_learning_info() - - print("โœ… Sensitivity learning info retrieved") - print(f" - Info structure: {list(sensitivity_info.keys())}") - - # Check for extrema information - if 'extrema' in sensitivity_info: - extrema_info = sensitivity_info['extrema'] - print(f" - Extrema info available: โœ…") - print(f" - Total extrema detected: {extrema_info.get('total_extrema_detected', 0)}") - print(f" - Training queue size: {extrema_info.get('training_queue_size', 0)}") - - recent_extrema = extrema_info.get('recent_extrema', {}) - if recent_extrema: - print(f" - Recent bottoms: {recent_extrema.get('bottoms', 0)}") - print(f" - Recent tops: {recent_extrema.get('tops', 0)}") - print(f" - Average confidence: {recent_extrema.get('avg_confidence', 0):.3f}") - - # Check for context data information - if 'context_data' in sensitivity_info: - context_info = sensitivity_info['context_data'] - print(f" - Context data info available: โœ…") - print(f" - Symbols with context: {len(context_info)}") - - for symbol, status in list(context_info.items())[:2]: # Show first 2 - candles = status.get('candles_loaded', 0) - features = status.get('features_available', False) - print(f" - {symbol}: {candles} candles, features: {'โœ…' if features else 'โŒ'}") - - # Test model training status creation - try: - training_status = dashboard._create_model_training_status() - print("โœ… Model training status created successfully") - print(f" - Status type: {type(training_status)}") - except Exception as e: - print(f"โš ๏ธ Model training status creation had issues: {e}") - - return True - - except Exception as e: - print(f"โŒ Dashboard integration failed: {e}") - return False - -def test_reusability_across_dashboards(): - """Test 8: Reusability of extrema trainer across different dashboards""" - print("\n" + "="*60) - print("TEST 8: Reusability Across Different Dashboards") - print("="*60) - - try: - from core.extrema_trainer import ExtremaTrainer - from core.data_provider import DataProvider - - # Create shared extrema trainer - data_provider = DataProvider() - shared_extrema_trainer = ExtremaTrainer( - data_provider=data_provider, - symbols=['ETHUSDT'], - window_size=8 # Different window size - ) - - # Initialize context data - shared_extrema_trainer.initialize_context_data() - - print("โœ… Shared extrema trainer created") - print(f" - Window size: {shared_extrema_trainer.window_size}") - print(f" - Symbols: {shared_extrema_trainer.symbols}") - - # Simulate usage by multiple dashboard types - dashboard_types = ['scalping', 'swing', 'analysis'] - - for dashboard_type in dashboard_types: - print(f"\n Testing {dashboard_type} dashboard usage:") - - # Get extrema stats (reusable method) - stats = shared_extrema_trainer.get_extrema_stats() - print(f" - {dashboard_type}: Extrema stats retrieved โœ…") - - # Get context features (reusable method) - features = shared_extrema_trainer.get_context_features_for_model('ETHUSDT') - if features is not None: - print(f" - {dashboard_type}: Context features available โœ… {features.shape}") - else: - print(f" - {dashboard_type}: Context features not available โŒ") - - # Get training data (reusable method) - training_data = shared_extrema_trainer.get_extrema_training_data(count=5) - print(f" - {dashboard_type}: Training data retrieved โœ… ({len(training_data)} cases)") - - # Get perfect moves (reusable method) - perfect_moves = shared_extrema_trainer.get_perfect_moves_for_cnn(count=5) - print(f" - {dashboard_type}: Perfect moves retrieved โœ… ({len(perfect_moves)} moves)") - - print("\nโœ… Extrema trainer successfully reused across different dashboard types") - - return True - - except Exception as e: - print(f"โŒ Reusability test failed: {e}") - return False - -def run_comprehensive_test_suite(): - """Run the complete test suite""" - print("๐Ÿš€ ENHANCED EXTREMA TRAINING TEST SUITE") - print("="*80) - print("Testing 200-candle context data, extrema detection, and dashboard integration") - print("="*80) - - test_results = [] - extrema_trainer = None - orchestrator = None - - # Test 1: Extrema trainer initialization - success, extrema_trainer = test_extrema_trainer_initialization() - test_results.append(("Extrema Trainer Initialization", success)) - - if success and extrema_trainer: - # Test 2: Context data loading - success = test_context_data_loading(extrema_trainer) - test_results.append(("200-Candle Context Data Loading", success)) - - # Test 3: Extrema detection - success = test_extrema_detection(extrema_trainer) - test_results.append(("Local Extrema Detection", success)) - - # Test 4: Context data updates - success = test_context_data_updates(extrema_trainer) - test_results.append(("Context Data Updates", success)) - - # Test 5: Stats and training data - success = test_extrema_stats_and_training_data(extrema_trainer) - test_results.append(("Extrema Stats and Training Data", success)) - - # Test 6: Enhanced orchestrator integration - success, orchestrator = test_enhanced_orchestrator_integration() - test_results.append(("Enhanced Orchestrator Integration", success)) - - if success and orchestrator: - # Test 7: Dashboard integration - success = test_dashboard_integration(orchestrator) - test_results.append(("Dashboard Integration", success)) - - # Test 8: Reusability - success = test_reusability_across_dashboards() - test_results.append(("Reusability Across Dashboards", success)) - - # Print final results - print("\n" + "="*80) - print("๐Ÿ TEST SUITE RESULTS") - print("="*80) - - passed = 0 - total = len(test_results) - - for test_name, success in test_results: - status = "โœ… PASSED" if success else "โŒ FAILED" - print(f"{test_name:<40} {status}") - if success: - passed += 1 - - print("="*80) - print(f"OVERALL RESULT: {passed}/{total} tests passed ({passed/total*100:.1f}%)") - - if passed == total: - print("๐ŸŽ‰ ALL TESTS PASSED! Enhanced extrema training system is working correctly.") - elif passed >= total * 0.8: - print("โœ… MOSTLY SUCCESSFUL! System is functional with minor issues.") - else: - print("โš ๏ธ SIGNIFICANT ISSUES DETECTED! Please review failed tests.") - - print("="*80) - - return passed, total - -if __name__ == "__main__": - try: - passed, total = run_comprehensive_test_suite() - - # Exit with appropriate code - if passed == total: - sys.exit(0) # Success - else: - sys.exit(1) # Some failures - - except KeyboardInterrupt: - print("\n\nโš ๏ธ Test suite interrupted by user") - sys.exit(2) - except Exception as e: - print(f"\n\nโŒ Test suite crashed: {e}") - import traceback - traceback.print_exc() - sys.exit(3) \ No newline at end of file diff --git a/tests/test_fee_sync.py b/tests/test_fee_sync.py deleted file mode 100644 index b56a106..0000000 --- a/tests/test_fee_sync.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Test script for automatic fee synchronization with MEXC API - -This script demonstrates how the system can automatically sync trading fees -from the MEXC API to the local configuration file. -""" - -import os -import sys -import logging -from dotenv import load_dotenv - -# Add NN directory to path -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) - -from NN.exchanges.mexc_interface import MEXCInterface -from core.config_sync import ConfigSynchronizer -from core.trading_executor import TradingExecutor - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_mexc_fee_retrieval(): - """Test retrieving fees directly from MEXC API""" - logger.info("=== Testing MEXC Fee Retrieval ===") - - # Load environment variables - load_dotenv() - - # Initialize MEXC interface - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - logger.error("MEXC API credentials not found in environment variables") - return None - - try: - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=False) - - # Test connection - if mexc.connect(): - logger.info("MEXC: Connection successful") - else: - logger.error("MEXC: Connection failed") - return None - - # Get trading fees - logger.info("MEXC: Fetching trading fees...") - fees = mexc.get_trading_fees() - - if fees: - logger.info(f"MEXC Trading Fees Retrieved:") - logger.info(f" Maker Rate: {fees.get('maker_rate', 0)*100:.3f}%") - logger.info(f" Taker Rate: {fees.get('taker_rate', 0)*100:.3f}%") - logger.info(f" Source: {fees.get('source', 'unknown')}") - - if fees.get('source') == 'mexc_api': - logger.info(f" Raw Commission Rates:") - logger.info(f" Maker: {fees.get('maker_commission', 0)} basis points") - logger.info(f" Taker: {fees.get('taker_commission', 0)} basis points") - else: - logger.warning("Using fallback fee values - API may not be working") - else: - logger.error("Failed to retrieve trading fees") - - return fees - - except Exception as e: - logger.error(f"Error testing MEXC fee retrieval: {e}") - return None - -def test_config_synchronization(): - """Test automatic config synchronization""" - logger.info("\n=== Testing Config Synchronization ===") - - # Load environment variables - load_dotenv() - - try: - # Initialize MEXC interface - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - logger.error("MEXC API credentials not found") - return False - - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=False) - - # Initialize config synchronizer - config_sync = ConfigSynchronizer( - config_path="config.yaml", - mexc_interface=mexc - ) - - # Get current sync status - logger.info("Current sync status:") - status = config_sync.get_sync_status() - for key, value in status.items(): - if key != 'latest_sync_result': - logger.info(f" {key}: {value}") - - # Perform manual sync - logger.info("\nPerforming manual fee synchronization...") - sync_result = config_sync.sync_trading_fees(force=True) - - logger.info(f"Sync Result:") - logger.info(f" Status: {sync_result.get('status')}") - logger.info(f" Changes Made: {sync_result.get('changes_made', False)}") - - if sync_result.get('changes'): - logger.info(" Fee Changes:") - for fee_type, change in sync_result['changes'].items(): - logger.info(f" {fee_type}: {change['old']:.6f} -> {change['new']:.6f}") - - if sync_result.get('errors'): - logger.warning(f" Errors: {sync_result['errors']}") - - # Test auto-sync - logger.info("\nTesting auto-sync...") - auto_sync_success = config_sync.auto_sync_fees() - logger.info(f"Auto-sync result: {'Success' if auto_sync_success else 'Failed/Skipped'}") - - return sync_result.get('status') in ['success', 'no_changes'] - - except Exception as e: - logger.error(f"Error testing config synchronization: {e}") - return False - -def test_trading_executor_integration(): - """Test fee sync integration in TradingExecutor""" - logger.info("\n=== Testing TradingExecutor Integration ===") - - try: - # Initialize trading executor (this should trigger automatic fee sync) - logger.info("Initializing TradingExecutor with auto fee sync...") - executor = TradingExecutor("config.yaml") - - # Check if config synchronizer was initialized - if hasattr(executor, 'config_synchronizer') and executor.config_synchronizer: - logger.info("Config synchronizer successfully initialized") - - # Get sync status - sync_status = executor.get_fee_sync_status() - logger.info("Fee sync status:") - for key, value in sync_status.items(): - if key not in ['latest_sync_result']: - logger.info(f" {key}: {value}") - - # Test manual sync through executor - logger.info("\nTesting manual sync through TradingExecutor...") - manual_sync = executor.sync_fees_with_api(force=True) - logger.info(f"Manual sync result: {manual_sync.get('status')}") - - # Test auto sync - logger.info("Testing auto sync...") - auto_sync = executor.auto_sync_fees_if_needed() - logger.info(f"Auto sync result: {'Success' if auto_sync else 'Skipped/Failed'}") - - return True - else: - logger.error("Config synchronizer not initialized in TradingExecutor") - return False - - except Exception as e: - logger.error(f"Error testing TradingExecutor integration: {e}") - return False - -def main(): - """Run all tests""" - logger.info("Starting Fee Synchronization Tests") - logger.info("=" * 50) - - # Test 1: Direct API fee retrieval - fees = test_mexc_fee_retrieval() - - # Test 2: Config synchronization - if fees: - sync_success = test_config_synchronization() - else: - logger.warning("Skipping config sync test due to API failure") - sync_success = False - - # Test 3: TradingExecutor integration - if sync_success: - integration_success = test_trading_executor_integration() - else: - logger.warning("Skipping TradingExecutor test due to sync failure") - integration_success = False - - # Summary - logger.info("\n" + "=" * 50) - logger.info("TEST SUMMARY:") - logger.info(f" MEXC API Fee Retrieval: {'PASS' if fees else 'FAIL'}") - logger.info(f" Config Synchronization: {'PASS' if sync_success else 'FAIL'}") - logger.info(f" TradingExecutor Integration: {'PASS' if integration_success else 'FAIL'}") - - if fees and sync_success and integration_success: - logger.info("\nALL TESTS PASSED! Fee synchronization is working correctly.") - logger.info("Your system will now automatically sync trading fees from MEXC API.") - else: - logger.warning("\nSome tests failed. Check the logs above for details.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_final_fixes.py b/tests/test_final_fixes.py deleted file mode 100644 index 24dc5e5..0000000 --- a/tests/test_final_fixes.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -""" -Final Test - Verify Enhanced Orchestrator Methods Work -""" - -import sys -from pathlib import Path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -def test_final_fixes(): - """Test that the enhanced orchestrator methods are working""" - print("=" * 60) - print("FINAL TEST - ENHANCED RL PIPELINE FIXES") - print("=" * 60) - - try: - # Import and test basic orchestrator - from core.orchestrator import TradingOrchestrator - from core.data_provider import DataProvider - - print("โœ“ Imports successful") - - # Create orchestrator - dp = DataProvider() - orch = TradingOrchestrator(dp) - print("โœ“ TradingOrchestrator created") - - # Test enhanced methods - methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward'] - print("\nTesting enhanced methods:") - - for method in methods: - has_method = hasattr(orch, method) - print(f" {method}: {'โœ“' if has_method else 'โœ—'}") - - # Test comprehensive RL state building - print("\nTesting comprehensive RL state building:") - state = orch.build_comprehensive_rl_state('ETH/USDT') - if state and len(state) >= 13000: - print(f"โœ… Comprehensive RL state: {len(state)} features (AUDIT FIXED)") - else: - print(f"โŒ Comprehensive RL state: {len(state) if state else 0} features") - - # Test enhanced reward calculation - print("\nTesting enhanced pivot reward:") - mock_trade_outcome = {'net_pnl': 25.0, 'hold_time_seconds': 300} - mock_market_data = {'current_price': 3500.0, 'trend_strength': 0.8, 'volatility': 0.1} - mock_trade_decision = {'price': 3495.0} - - reward = orch.calculate_enhanced_pivot_reward( - mock_trade_decision, - mock_market_data, - mock_trade_outcome - ) - print(f"โœ… Enhanced pivot reward: {reward:.4f}") - - # Test dashboard integration - print("\nTesting dashboard integration:") - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - - # Create dashboard with basic orchestrator (should work now) - dashboard = TradingDashboard(data_provider=dp, orchestrator=orch) - print("โœ“ Dashboard created with enhanced orchestrator") - - # Test dashboard can access enhanced methods - dashboard_has_enhanced = hasattr(dashboard.orchestrator, 'build_comprehensive_rl_state') - print(f" Dashboard has enhanced methods: {'โœ“' if dashboard_has_enhanced else 'โœ—'}") - - if dashboard_has_enhanced: - dashboard_state = dashboard.orchestrator.build_comprehensive_rl_state('ETH/USDT') - print(f" Dashboard comprehensive state: {len(dashboard_state) if dashboard_state else 0} features") - - print("\n" + "=" * 60) - print("๐ŸŽ‰ COMPREHENSIVE RL TRAINING PIPELINE FIXES COMPLETE!") - print("=" * 60) - print("โœ… AUDIT ISSUE #1: INPUT DATA GAP FIXED") - print(" - Comprehensive RL state: 13,400+ features") - print(" - ETH tick data, multi-timeframe OHLCV, BTC reference") - print(" - CNN features, pivot analysis, microstructure") - print("") - print("โœ… AUDIT ISSUE #2: ENHANCED REWARD CALCULATION FIXED") - print(" - Pivot-based reward system operational") - print(" - Market structure analysis integrated") - print(" - Trade execution quality assessment") - print("") - print("โœ… AUDIT ISSUE #3: ORCHESTRATOR INTEGRATION FIXED") - print(" - Dashboard can access enhanced methods") - print(" - No async/sync conflicts") - print(" - Real-time training data collection ready") - print("") - print("๐Ÿš€ READY FOR REAL-TIME TRAINING WITH RETROSPECTIVE SETUPS!") - print("=" * 60) - - return True - - except Exception as e: - print(f"\nโŒ ERROR: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_final_fixes() - if success: - print("\nโœ… All pipeline fixes verified and working!") - else: - print("\nโŒ Pipeline fixes need more work") \ No newline at end of file diff --git a/tests/test_free_orderbook_integration.py b/tests/test_free_orderbook_integration.py deleted file mode 100644 index abb9063..0000000 Binary files a/tests/test_free_orderbook_integration.py and /dev/null differ diff --git a/tests/test_gpu_training.py b/tests/test_gpu_training.py deleted file mode 100644 index c5d9c9b..0000000 --- a/tests/test_gpu_training.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python3 -""" -Test GPU Training - Check if our models actually train and use GPU -""" - -import torch -import torch.nn as nn -import torch.optim as optim -import numpy as np -import time -import logging -from pathlib import Path -import sys - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_gpu_availability(): - """Test if GPU is available and working""" - logger.info("=== GPU AVAILABILITY TEST ===") - - print(f"PyTorch version: {torch.__version__}") - print(f"CUDA available: {torch.cuda.is_available()}") - - if torch.cuda.is_available(): - print(f"CUDA version: {torch.version.cuda}") - print(f"GPU count: {torch.cuda.device_count()}") - for i in range(torch.cuda.device_count()): - print(f"GPU {i}: {torch.cuda.get_device_name(i)}") - print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB") - - # Test GPU operations - try: - device = torch.device('cuda:0') - x = torch.randn(100, 100, device=device) - y = torch.randn(100, 100, device=device) - z = torch.mm(x, y) - print(f"โœ… GPU operations working: {z.device}") - return True - except Exception as e: - print(f"โŒ GPU operations failed: {e}") - return False - else: - print("โŒ No CUDA available") - return False - -def test_simple_training(): - """Test if a simple neural network actually trains""" - logger.info("=== SIMPLE TRAINING TEST ===") - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Using device: {device}") - - # Create a simple model - class SimpleNet(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.Sequential( - nn.Linear(10, 64), - nn.ReLU(), - nn.Linear(64, 32), - nn.ReLU(), - nn.Linear(32, 3) - ) - - def forward(self, x): - return self.layers(x) - - model = SimpleNet().to(device) - optimizer = optim.Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss() - - # Generate some dummy data - X = torch.randn(1000, 10, device=device) - y = torch.randint(0, 3, (1000,), device=device) - - print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - print(f"Data shape: {X.shape}, Labels shape: {y.shape}") - - # Training loop - initial_loss = None - losses = [] - - print("Training for 100 steps...") - start_time = time.time() - - for step in range(100): - # Forward pass - outputs = model(X) - loss = criterion(outputs, y) - - # Backward pass - optimizer.zero_grad() - loss.backward() - optimizer.step() - - loss_val = loss.item() - losses.append(loss_val) - - if step == 0: - initial_loss = loss_val - - if step % 20 == 0: - print(f"Step {step}: Loss = {loss_val:.4f}") - - end_time = time.time() - final_loss = losses[-1] - - print(f"Training completed in {end_time - start_time:.2f} seconds") - print(f"Initial loss: {initial_loss:.4f}") - print(f"Final loss: {final_loss:.4f}") - print(f"Loss reduction: {initial_loss - final_loss:.4f}") - - # Check if training actually happened - if final_loss < initial_loss * 0.9: # At least 10% reduction - print("โœ… Training is working - loss decreased significantly") - return True - else: - print("โŒ Training may not be working - loss didn't decrease much") - return False - -def test_our_models(): - """Test if our actual models can train""" - logger.info("=== OUR MODELS TEST ===") - - try: - # Test DQN Agent - from NN.models.dqn_agent import DQNAgent - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Testing DQN Agent on {device}") - - # Create agent - state_shape = (100,) # Simple state - agent = DQNAgent( - state_shape=state_shape, - n_actions=3, - learning_rate=0.001, - device=device - ) - - print(f"โœ… DQN Agent created successfully") - print(f" Device: {agent.device}") - print(f" Policy net device: {next(agent.policy_net.parameters()).device}") - - # Test training step - state = np.random.randn(100).astype(np.float32) - action = 1 - reward = 0.5 - next_state = np.random.randn(100).astype(np.float32) - done = False - - # Add experience and train - agent.remember(state, action, reward, next_state, done) - - # Add more experiences - for _ in range(200): # Need enough for batch - s = np.random.randn(100).astype(np.float32) - a = np.random.randint(0, 3) - r = np.random.randn() * 0.1 - ns = np.random.randn(100).astype(np.float32) - d = np.random.random() < 0.1 - agent.remember(s, a, r, ns, d) - - # Test training - print("Testing training step...") - initial_loss = None - for i in range(10): - loss = agent.replay() - if loss > 0: - if initial_loss is None: - initial_loss = loss - print(f" Step {i}: Loss = {loss:.4f}") - - if initial_loss is not None: - print("โœ… DQN training is working") - else: - print("โŒ DQN training returned no loss") - - return True - - except Exception as e: - print(f"โŒ Error testing our models: {e}") - import traceback - traceback.print_exc() - return False - -def test_cnn_model(): - """Test CNN model training""" - logger.info("=== CNN MODEL TEST ===") - - try: - from NN.models.enhanced_cnn import EnhancedCNN - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Testing Enhanced CNN on {device}") - - # Create model - state_dim = (3, 20, 26) # 3 timeframes, 20 window, 26 features - n_actions = 3 - - model = EnhancedCNN(state_dim, n_actions).to(device) - optimizer = optim.Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss() - - print(f"โœ… Enhanced CNN created successfully") - print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") - - # Test forward pass - batch_size = 32 - x = torch.randn(batch_size, 3, 20, 26, device=device) - - print("Testing forward pass...") - outputs = model(x) - - if isinstance(outputs, tuple): - action_probs, extrema_pred, price_pred, features, advanced_pred = outputs - print(f"โœ… Forward pass successful") - print(f" Action probs shape: {action_probs.shape}") - print(f" Features shape: {features.shape}") - else: - print(f"โŒ Unexpected output format: {type(outputs)}") - return False - - # Test training step - y = torch.randint(0, 3, (batch_size,), device=device) - - print("Testing training step...") - loss = criterion(action_probs, y) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - print(f"โœ… CNN training step successful, loss: {loss.item():.4f}") - return True - - except Exception as e: - print(f"โŒ Error testing CNN model: {e}") - import traceback - traceback.print_exc() - return False - -def main(): - """Run all tests""" - print("=" * 60) - print("TESTING GPU TRAINING FUNCTIONALITY") - print("=" * 60) - - results = {} - - # Test 1: GPU availability - results['gpu'] = test_gpu_availability() - print() - - # Test 2: Simple training - results['simple_training'] = test_simple_training() - print() - - # Test 3: Our DQN models - results['dqn_models'] = test_our_models() - print() - - # Test 4: CNN models - results['cnn_models'] = test_cnn_model() - print() - - # Summary - print("=" * 60) - print("TEST RESULTS SUMMARY") - print("=" * 60) - - for test_name, passed in results.items(): - status = "โœ… PASS" if passed else "โŒ FAIL" - print(f"{test_name.upper()}: {status}") - - all_passed = all(results.values()) - - if all_passed: - print("\n๐ŸŽ‰ ALL TESTS PASSED - Your training should work with GPU!") - else: - print("\nโš ๏ธ SOME TESTS FAILED - Check the issues above") - - if not results['gpu']: - print(" โ†’ GPU not available or not working") - if not results['simple_training']: - print(" โ†’ Basic training loop not working") - if not results['dqn_models']: - print(" โ†’ DQN models have issues") - if not results['cnn_models']: - print(" โ†’ CNN models have issues") - - return 0 if all_passed else 1 - -if __name__ == "__main__": - exit_code = main() - sys.exit(exit_code) \ No newline at end of file diff --git a/tests/test_indicators_and_signals.py b/tests/test_indicators_and_signals.py deleted file mode 100644 index cc133d4..0000000 --- a/tests/test_indicators_and_signals.py +++ /dev/null @@ -1,402 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive Indicators and Signals Test Suite - -This module consolidates testing functionality for: -- Technical indicators (from test_indicators.py) -- Signal interpretation and processing (from test_signal_interpreter.py) -- Market data analysis -- Trading signal validation -""" - -import sys -import os -import unittest -import logging -import numpy as np -import tempfile -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging -from core.data_provider import DataProvider - -logger = logging.getLogger(__name__) - -class TestTechnicalIndicators(unittest.TestCase): - """Test suite for technical indicators functionality""" - - def setUp(self): - """Set up test fixtures""" - setup_logging() - self.data_provider = DataProvider(['ETH/USDT'], ['1h']) - - def test_indicator_calculation(self): - """Test that indicators are calculated correctly""" - logger.info("Testing technical indicators calculation...") - - try: - # Fetch data with indicators - df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100) - - self.assertIsNotNone(df, "Should fetch data successfully") - self.assertGreater(len(df), 0, "Should have data rows") - - # Check basic OHLCV columns - basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] - for col in basic_cols: - self.assertIn(col, df.columns, f"Should have {col} column") - - # Check that indicators are calculated - indicator_cols = [col for col in df.columns if col not in basic_cols] - self.assertGreater(len(indicator_cols), 0, "Should have technical indicators") - - logger.info(f"โœ… Successfully calculated {len(indicator_cols)} indicators") - - except Exception as e: - logger.warning(f"Indicator test failed: {e}") - self.skipTest("Data or indicators not available") - - def test_indicator_categorization(self): - """Test categorization of different indicator types""" - logger.info("Testing indicator categorization...") - - try: - df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100) - - if df is not None: - basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] - indicator_cols = [col for col in df.columns if col not in basic_cols] - - # Categorize indicators - trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])] - momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])] - volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])] - volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])] - - # Check we have indicators in each category - total_categorized = len(trend_indicators) + len(momentum_indicators) + len(volatility_indicators) + len(volume_indicators) - - logger.info(f"Indicator categories:") - logger.info(f" Trend: {len(trend_indicators)}") - logger.info(f" Momentum: {len(momentum_indicators)}") - logger.info(f" Volatility: {len(volatility_indicators)}") - logger.info(f" Volume: {len(volume_indicators)}") - logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}") - - self.assertGreater(total_categorized, 0, "Should have categorized indicators") - - else: - self.skipTest("Could not fetch data for categorization test") - - except Exception as e: - logger.warning(f"Categorization test failed: {e}") - self.skipTest("Indicator categorization not available") - - def test_feature_matrix_creation(self): - """Test multi-timeframe feature matrix creation""" - logger.info("Testing feature matrix creation...") - - try: - # Test feature matrix with multiple timeframes - feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20) - - if feature_matrix is not None: - self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix") - self.assertGreater(feature_matrix.shape[2], 0, "Should have features") - - logger.info(f"โœ… Feature matrix shape: {feature_matrix.shape}") - - else: - self.skipTest("Could not create feature matrix") - - except Exception as e: - logger.warning(f"Feature matrix test failed: {e}") - self.skipTest("Feature matrix creation not available") - -class TestSignalProcessing(unittest.TestCase): - """Test suite for signal interpretation and processing""" - - def test_signal_distribution_calculation(self): - """Test signal distribution calculation""" - logger.info("Testing signal distribution calculation...") - - # Mock predictions (SELL=0, HOLD=1, BUY=2) - predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) - - buy_count = np.sum(predictions == 2) - sell_count = np.sum(predictions == 0) - hold_count = np.sum(predictions == 1) - total = len(predictions) - - distribution = { - "BUY": buy_count / total, - "SELL": sell_count / total, - "HOLD": hold_count / total - } - - # Verify calculations - self.assertAlmostEqual(distribution["BUY"], 0.3, places=2) - self.assertAlmostEqual(distribution["SELL"], 0.3, places=2) - self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2) - self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2) - - logger.info("โœ… Signal distribution calculation test passed") - - def test_basic_signal_interpretation(self): - """Test basic signal interpretation logic""" - logger.info("Testing basic signal interpretation...") - - # Test cases with different probability distributions - test_cases = [ - { - 'probs': [0.8, 0.1, 0.1], # Strong SELL - 'expected_action': 'SELL', - 'expected_confidence': 'high' - }, - { - 'probs': [0.1, 0.1, 0.8], # Strong BUY - 'expected_action': 'BUY', - 'expected_confidence': 'high' - }, - { - 'probs': [0.1, 0.8, 0.1], # Strong HOLD - 'expected_action': 'HOLD', - 'expected_confidence': 'high' - }, - { - 'probs': [0.4, 0.3, 0.3], # Uncertain - should prefer SELL (index 0) - 'expected_action': 'SELL', - 'expected_confidence': 'low' - }, - { - 'probs': [0.33, 0.33, 0.34], # Very uncertain - slight BUY preference - 'expected_action': 'BUY', - 'expected_confidence': 'low' - } - ] - - for i, test_case in enumerate(test_cases): - probs = np.array(test_case['probs']) - expected_action = test_case['expected_action'] - - # Simple signal interpretation (argmax) - predicted_action_idx = np.argmax(probs) - action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'} - predicted_action = action_map[predicted_action_idx] - - # Calculate confidence (max probability) - confidence = np.max(probs) - confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low' - - # Verify predictions - self.assertEqual(predicted_action, expected_action, - f"Test case {i+1}: Expected {expected_action}, got {predicted_action}") - - logger.info(f"Test case {i+1}: {probs} -> {predicted_action} ({confidence_level} confidence)") - - logger.info("โœ… Basic signal interpretation test passed") - - def test_signal_filtering_logic(self): - """Test signal filtering and validation logic""" - logger.info("Testing signal filtering logic...") - - # Test threshold-based filtering - buy_threshold = 0.6 - sell_threshold = 0.6 - hold_threshold = 0.7 - - test_signals = [ - { - 'probs': [0.8, 0.1, 0.1], # Strong SELL (above threshold) - 'should_pass': True, - 'expected': 'SELL' - }, - { - 'probs': [0.5, 0.3, 0.2], # Weak SELL (below threshold) - 'should_pass': False, - 'expected': 'HOLD' - }, - { - 'probs': [0.1, 0.2, 0.7], # Strong BUY (above threshold) - 'should_pass': True, - 'expected': 'BUY' - }, - { - 'probs': [0.2, 0.8, 0.0], # Strong HOLD (above threshold) - 'should_pass': True, - 'expected': 'HOLD' - } - ] - - for i, test in enumerate(test_signals): - probs = np.array(test['probs']) - sell_prob, hold_prob, buy_prob = probs - - # Apply threshold filtering - if sell_prob >= sell_threshold: - filtered_action = 'SELL' - passed_filter = True - elif buy_prob >= buy_threshold: - filtered_action = 'BUY' - passed_filter = True - elif hold_prob >= hold_threshold: - filtered_action = 'HOLD' - passed_filter = True - else: - filtered_action = 'HOLD' # Default to HOLD if no threshold met - passed_filter = False - - # Verify filtering - expected_pass = test['should_pass'] - expected_action = test['expected'] - - self.assertEqual(passed_filter, expected_pass, - f"Test {i+1}: Filter pass expectation failed") - self.assertEqual(filtered_action, expected_action, - f"Test {i+1}: Expected {expected_action}, got {filtered_action}") - - logger.info(f"Test {i+1}: {probs} -> {filtered_action} (passed: {passed_filter})") - - logger.info("โœ… Signal filtering logic test passed") - - def test_signal_sequence_validation(self): - """Test signal sequence validation and oscillation prevention""" - logger.info("Testing signal sequence validation...") - - # Simulate a sequence of signals that might oscillate - signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY'] - - # Simple oscillation detection - oscillation_count = 0 - for i in range(1, len(signal_sequence)): - if (signal_sequence[i-1] == 'BUY' and signal_sequence[i] == 'SELL') or \ - (signal_sequence[i-1] == 'SELL' and signal_sequence[i] == 'BUY'): - oscillation_count += 1 - - # Count consecutive non-HOLD signals - consecutive_trades = 0 - max_consecutive = 0 - for signal in signal_sequence: - if signal != 'HOLD': - consecutive_trades += 1 - max_consecutive = max(max_consecutive, consecutive_trades) - else: - consecutive_trades = 0 - - # Verify oscillation detection - self.assertGreater(oscillation_count, 0, "Should detect oscillations in test sequence") - self.assertGreater(max_consecutive, 1, "Should detect consecutive trades") - - logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades") - logger.info("โœ… Signal sequence validation test passed") - -class TestMarketDataAnalysis(unittest.TestCase): - """Test suite for market data analysis functionality""" - - def test_price_movement_calculation(self): - """Test price movement and trend calculation""" - logger.info("Testing price movement calculation...") - - # Mock price data - prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1]) - - # Calculate price movements - price_changes = np.diff(prices) - percentage_changes = (price_changes / prices[:-1]) * 100 - - # Calculate simple trend - recent_trend = np.mean(percentage_changes[-3:]) # Last 3 changes - trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways' - - # Verify calculations - self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes") - self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes") - - # Verify trend detection makes sense - self.assertIn(trend_direction, ['uptrend', 'downtrend', 'sideways'], "Should detect valid trend") - - logger.info(f"Price sequence: {prices}") - logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)") - logger.info("โœ… Price movement calculation test passed") - - def test_volatility_measurement(self): - """Test volatility measurement""" - logger.info("Testing volatility measurement...") - - # Mock price data with different volatility - stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0]) - volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0]) - - # Calculate volatility (standard deviation of returns) - def calculate_volatility(prices): - returns = np.diff(prices) / prices[:-1] - return np.std(returns) * 100 # As percentage - - stable_vol = calculate_volatility(stable_prices) - volatile_vol = calculate_volatility(volatile_prices) - - # Verify volatility measurements - self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility") - self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility") - - logger.info(f"Stable volatility: {stable_vol:.2f}%") - logger.info(f"Volatile volatility: {volatile_vol:.2f}%") - logger.info("โœ… Volatility measurement test passed") - -def run_indicator_tests(): - """Run indicator tests only""" - suite = unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(suite) - return result.wasSuccessful() - -def run_signal_tests(): - """Run signal processing tests only""" - test_suites = [ - unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing), - unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis), - ] - - combined_suite = unittest.TestSuite(test_suites) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(combined_suite) - return result.wasSuccessful() - -def run_all_tests(): - """Run all indicator and signal tests""" - test_suites = [ - unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators), - unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing), - unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis), - ] - - combined_suite = unittest.TestSuite(test_suites) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(combined_suite) - return result.wasSuccessful() - -if __name__ == "__main__": - setup_logging() - logger.info("Running indicators and signals test suite...") - - if len(sys.argv) > 1: - test_type = sys.argv[1] - if test_type == "indicators": - success = run_indicator_tests() - elif test_type == "signals": - success = run_signal_tests() - else: - success = run_all_tests() - else: - success = run_all_tests() - - if success: - logger.info("โœ… All indicator and signal tests passed!") - sys.exit(0) - else: - logger.error("โŒ Some tests failed!") - sys.exit(1) \ No newline at end of file diff --git a/tests/test_leverage_slider.py b/tests/test_leverage_slider.py deleted file mode 100644 index ed7e1e8..0000000 --- a/tests/test_leverage_slider.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Leverage Slider Functionality - -This script tests the leverage slider integration in the dashboard: -- Verifies slider range (1x to 100x) -- Tests risk level calculation -- Checks leverage multiplier updates -""" - -import sys -import os -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.config import setup_logging -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - -# Setup logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def test_leverage_calculations(): - """Test leverage risk calculations""" - - logger.info("=" * 50) - logger.info("TESTING LEVERAGE CALCULATIONS") - logger.info("=" * 50) - - test_cases = [ - {'leverage': 1, 'expected_risk': 'Low Risk'}, - {'leverage': 5, 'expected_risk': 'Low Risk'}, - {'leverage': 10, 'expected_risk': 'Medium Risk'}, - {'leverage': 25, 'expected_risk': 'Medium Risk'}, - {'leverage': 30, 'expected_risk': 'High Risk'}, - {'leverage': 50, 'expected_risk': 'High Risk'}, - {'leverage': 75, 'expected_risk': 'Extreme Risk'}, - {'leverage': 100, 'expected_risk': 'Extreme Risk'}, - ] - - for test_case in test_cases: - leverage = test_case['leverage'] - expected_risk = test_case['expected_risk'] - - # Calculate risk level using same logic as dashboard - if leverage <= 5: - actual_risk = "Low Risk" - elif leverage <= 25: - actual_risk = "Medium Risk" - elif leverage <= 50: - actual_risk = "High Risk" - else: - actual_risk = "Extreme Risk" - - status = "PASS" if actual_risk == expected_risk else "FAIL" - logger.info(f" {leverage:3d}x leverage -> {actual_risk:13s} (expected: {expected_risk:13s}) [{status}]") - - if status == "FAIL": - logger.error(f"Test failed for {leverage}x leverage!") - return False - - logger.info("All leverage calculation tests PASSED!") - return True - -def test_leverage_reward_amplification(): - """Test how different leverage levels amplify rewards""" - - logger.info("\n" + "=" * 50) - logger.info("TESTING LEVERAGE REWARD AMPLIFICATION") - logger.info("=" * 50) - - base_price = 3000.0 - price_changes = [0.001, 0.002, 0.005, 0.01] # 0.1%, 0.2%, 0.5%, 1.0% - leverage_levels = [1, 5, 10, 25, 50, 100] - - logger.info("Price Change | " + " | ".join([f"{lev:3d}x" for lev in leverage_levels])) - logger.info("-" * 70) - - for price_change_pct in price_changes: - results = [] - for leverage in leverage_levels: - # Calculate amplified return - amplified_return = price_change_pct * leverage * 100 # Convert to percentage - results.append(f"{amplified_return:6.1f}%") - - logger.info(f" {price_change_pct*100:4.1f}% | " + " | ".join(results)) - - logger.info("\nKey insights:") - logger.info("- 1x leverage: Traditional trading returns") - logger.info("- 50x leverage: Our current default for enhanced learning") - logger.info("- 100x leverage: Maximum risk/reward amplification") - - return True - -def test_dashboard_integration(): - """Test dashboard integration""" - - logger.info("\n" + "=" * 50) - logger.info("TESTING DASHBOARD INTEGRATION") - logger.info("=" * 50) - - try: - # Initialize components - logger.info("Creating data provider...") - data_provider = DataProvider() - - logger.info("Creating enhanced orchestrator...") - orchestrator = EnhancedTradingOrchestrator(data_provider) - - logger.info("Creating trading dashboard...") - dashboard = TradingDashboard(data_provider, orchestrator) - - # Test leverage settings - logger.info(f"Initial leverage: {dashboard.leverage_multiplier}x") - logger.info(f"Leverage range: {dashboard.min_leverage}x to {dashboard.max_leverage}x") - logger.info(f"Leverage step: {dashboard.leverage_step}x") - - # Test leverage updates - test_leverages = [10, 25, 50, 75] - for test_leverage in test_leverages: - dashboard.leverage_multiplier = test_leverage - logger.info(f"Set leverage to {dashboard.leverage_multiplier}x") - - logger.info("Dashboard integration test PASSED!") - return True - - except Exception as e: - logger.error(f"Dashboard integration test FAILED: {e}") - return False - -def main(): - """Run all leverage tests""" - - logger.info("LEVERAGE SLIDER FUNCTIONALITY TEST") - logger.info("Testing the 50x leverage system with adjustable slider") - - all_passed = True - - # Test 1: Leverage calculations - if not test_leverage_calculations(): - all_passed = False - - # Test 2: Reward amplification - if not test_leverage_reward_amplification(): - all_passed = False - - # Test 3: Dashboard integration - if not test_dashboard_integration(): - all_passed = False - - # Final result - logger.info("\n" + "=" * 50) - if all_passed: - logger.info("ALL TESTS PASSED!") - logger.info("Leverage slider functionality is working correctly.") - logger.info("\nTo use:") - logger.info("1. Run: python run_clean_dashboard.py") - logger.info("2. Open: http://127.0.0.1:8050") - logger.info("3. Find the leverage slider in the System & Leverage panel") - logger.info("4. Adjust leverage from 1x to 100x") - logger.info("5. Watch risk levels update automatically") - else: - logger.error("SOME TESTS FAILED!") - logger.error("Check the error messages above.") - - return 0 if all_passed else 1 - -if __name__ == "__main__": - exit_code = main() - sys.exit(exit_code) \ No newline at end of file diff --git a/tests/test_manual_trading.py b/tests/test_manual_trading.py deleted file mode 100644 index 433c15a..0000000 --- a/tests/test_manual_trading.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for manual trading buttons functionality -""" - -import requests -import json -import time -from datetime import datetime - -def test_manual_trading(): - """Test the manual trading buttons functionality""" - print("Testing manual trading buttons...") - - # Check if dashboard is running - try: - response = requests.get("http://127.0.0.1:8050", timeout=5) - if response.status_code == 200: - print("โœ… Dashboard is running on port 8050") - else: - print(f"โŒ Dashboard returned status code: {response.status_code}") - return - except Exception as e: - print(f"โŒ Dashboard not accessible: {e}") - return - - # Check if trades file exists - try: - with open('closed_trades_history.json', 'r') as f: - trades = json.load(f) - print(f"๐Ÿ“Š Current trades in history: {len(trades)}") - if trades: - latest_trade = trades[-1] - print(f" Latest trade: {latest_trade.get('side')} at ${latest_trade.get('exit_price', latest_trade.get('entry_price'))}") - except FileNotFoundError: - print("๐Ÿ“Š No trades history file found (this is normal for fresh start)") - except Exception as e: - print(f"โŒ Error reading trades file: {e}") - - print("\n๐ŸŽฏ Manual Trading Test Instructions:") - print("1. Open dashboard at http://127.0.0.1:8050") - print("2. Look for the 'MANUAL BUY' and 'MANUAL SELL' buttons") - print("3. Click 'MANUAL BUY' to create a test long position") - print("4. Wait a few seconds, then click 'MANUAL SELL' to close and create short") - print("5. Check the chart for green triangles showing trade entry/exit points") - print("6. Check the 'Closed Trades' table for trade records") - - print("\n๐Ÿ“ˆ Expected Results:") - print("- Green triangles should appear on the price chart at trade execution times") - print("- Dashed lines should connect entry and exit points") - print("- Trade records should appear in the closed trades table") - print("- Session P&L should update with trade profits/losses") - - print("\n๐Ÿ” Monitoring trades file...") - initial_count = 0 - try: - with open('closed_trades_history.json', 'r') as f: - initial_count = len(json.load(f)) - except: - pass - - print(f"Initial trade count: {initial_count}") - print("Watching for new trades... (Press Ctrl+C to stop)") - - try: - while True: - time.sleep(2) - try: - with open('closed_trades_history.json', 'r') as f: - current_trades = json.load(f) - current_count = len(current_trades) - - if current_count > initial_count: - new_trades = current_trades[initial_count:] - for trade in new_trades: - print(f"๐Ÿ†• NEW TRADE: {trade.get('side')} | Entry: ${trade.get('entry_price'):.2f} | Exit: ${trade.get('exit_price'):.2f} | P&L: ${trade.get('net_pnl'):.2f}") - initial_count = current_count - - except FileNotFoundError: - pass - except Exception as e: - print(f"Error monitoring trades: {e}") - - except KeyboardInterrupt: - print("\nโœ… Test monitoring stopped") - -if __name__ == "__main__": - test_manual_trading() \ No newline at end of file diff --git a/tests/test_mexc_account_private.py b/tests/test_mexc_account_private.py deleted file mode 100644 index 34ddff9..0000000 --- a/tests/test_mexc_account_private.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import logging -import sys - -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Add project root to path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from NN.exchanges.mexc_interface import MEXCInterface -from core.config import get_config - -def test_mexc_private_api(): - """Test MEXC private API endpoints""" - # Load configuration - config = get_config('config.yaml') - mexc_config = config.get('mexc_trading', {}) - - # Get API credentials - api_key = os.getenv('MEXC_API_KEY', mexc_config.get('api_key', '')) - api_secret = os.getenv('MEXC_SECRET_KEY', mexc_config.get('api_secret', '')) - - if not api_key or not api_secret: - logger.error("API key or secret not found. Please set MEXC_API_KEY and MEXC_SECRET_KEY environment variables.") - return - - # Initialize MEXC interface in test mode - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True, trading_mode='simulation') - - # Test connection - if not mexc.connect(): - logger.error("Failed to connect to MEXC API") - return - - # Test getting account information - logger.info("Testing account information retrieval...") - account_info = mexc.get_account_info() - if account_info: - logger.info(f"Account info retrieved: {account_info}") - else: - logger.error("Failed to retrieve account info") - - # Test getting balance for a specific asset - asset = "USDT" - logger.info(f"Testing balance retrieval for {asset}...") - balance = mexc.get_balance(asset) - logger.info(f"Balance for {asset}: {balance}") - - # Test placing a simulated order (in test mode) - symbol = "ETH/USDT" - side = "buy" - order_type = "market" - quantity = 0.01 # Small quantity for testing - logger.info(f"Testing order placement for {symbol} ({side}, {order_type}, qty: {quantity})...") - order_result = mexc.place_order(symbol=symbol, side=side, order_type=order_type, quantity=quantity) - if order_result: - logger.info(f"Order placed successfully: {order_result}") - else: - logger.error("Failed to place order") - -if __name__ == "__main__": - test_mexc_private_api() \ No newline at end of file diff --git a/tests/test_mexc_account_privte.py b/tests/test_mexc_account_privte.py deleted file mode 100644 index c796297..0000000 --- a/tests/test_mexc_account_privte.py +++ /dev/null @@ -1,59 +0,0 @@ -import logging -import os -from NN.exchanges.mexc_interface import MEXCInterface - -# Set up logging to see debug info -logging.basicConfig(level=logging.INFO) - -# Load API credentials from environment variables or a configuration file -# For testing, prioritize environment variables for CI/CD or sensitive data -# Fallback to a placeholder or configuration reading if env vars are not set -api_key = os.getenv('MEXC_API_KEY', '') -api_secret = os.getenv('MEXC_SECRET_KEY', '') - -# If using a config file, you might do something like: -# from core.config import get_config -# config = get_config('config.yaml') -# mexc_config = config.get('mexc_trading', {}) -# api_key = mexc_config.get('api_key', api_key) -# api_secret = mexc_config.get('api_secret', api_secret) - -if not api_key or not api_secret: - logging.error("API keys are not set. Please set MEXC_API_KEY and MEXC_SECRET_KEY environment variables or configure config.yaml") - exit(1) - -# Create interface with API credentials -mexc = MEXCInterface( - api_key=api_key, - api_secret=api_secret, - trading_mode='simulation' -) - -print('MEXC Interface created successfully') - -# Test signature generation -import time -timestamp = int(time.time() * 1000) -test_params = 'quantity=1&price=11&symbol=BTCUSDT&side=BUY&type=LIMIT×tamp=' + str(timestamp) -signature = mexc._generate_signature(timestamp, test_params) -print(f'Generated signature: {signature}') - -# Test account info -print('Testing account info...') -account_info = mexc.get_account_info() -print(f'Account info result: {account_info}') - -# Test ticker data -print('Testing ticker data...') -ticker = mexc.get_ticker('ETH/USDT') -print(f'ETH/USDT ticker: {ticker}') - -# Test balance retrieval -print('Testing balance retrieval...') -usdt_balance = mexc.get_balance('USDT') -print(f'USDT balance: {usdt_balance}') - -# Test a small order placement (simulation mode) -print('Testing order placement in simulation mode...') -order_result = mexc.place_order('ETH/USDT', 'buy', 'market', 0.001) -print(f'Order result: {order_result}') \ No newline at end of file diff --git a/tests/test_mexc_balance_orders.py b/tests/test_mexc_balance_orders.py deleted file mode 100644 index 0fbf701..0000000 --- a/tests/test_mexc_balance_orders.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for MEXC balance retrieval and $1 order execution -""" - -import sys -import os -import logging -from pathlib import Path - -# Add project root to path -sys.path.insert(0, os.path.abspath('.')) - -from core.trading_executor import TradingExecutor -from core.data_provider import DataProvider - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_mexc_balance(): - """Test MEXC balance retrieval""" - print("="*60) - print("TESTING MEXC BALANCE RETRIEVAL") - print("="*60) - - try: - # Initialize trading executor - executor = TradingExecutor() - - # Check if trading is enabled - print(f"Trading enabled: {executor.trading_enabled}") - print(f"Dry run mode: {executor.dry_run}") - - if not executor.trading_enabled: - print("โŒ Trading not enabled - check config.yaml and API keys") - return False - - # Test balance retrieval - print("\n๐Ÿ“Š Retrieving account balance...") - balances = executor.get_account_balance() - - if not balances: - print("โŒ No balances retrieved - check API connectivity") - return False - - print(f"โœ… Retrieved balances for {len(balances)} assets:") - for asset, balance_info in balances.items(): - free = balance_info['free'] - locked = balance_info['locked'] - total = balance_info['total'] - print(f" {asset}: Free: {free:.6f}, Locked: {locked:.6f}, Total: {total:.6f}") - - # Check USDT balance specifically - if 'USDT' in balances: - usdt_free = balances['USDT']['free'] - print(f"\n๐Ÿ’ฐ USDT available for trading: ${usdt_free:.2f}") - - if usdt_free >= 2.0: # Need at least $2 for testing - print("โœ… Sufficient USDT balance for $1 order testing") - return True - else: - print(f"โš ๏ธ Insufficient USDT balance for testing (need $2+, have ${usdt_free:.2f})") - return False - else: - print("โŒ No USDT balance found") - return False - - except Exception as e: - logger.error(f"Error testing MEXC balance: {e}") - return False - -def test_mexc_order_execution(): - """Test $1 order execution (dry run)""" - print("\n" + "="*60) - print("TESTING $1 ORDER EXECUTION (DRY RUN)") - print("="*60) - - try: - # Initialize components - executor = TradingExecutor() - data_provider = DataProvider() - - if not executor.trading_enabled: - print("โŒ Trading not enabled - cannot test order execution") - return False - - # Test symbol - symbol = "ETH/USDT" - - # Get current price - print(f"\n๐Ÿ“ˆ Getting current price for {symbol}...") - ticker_data = data_provider.get_historical_data(symbol, '1m', limit=1, refresh=True) - - if ticker_data is None or ticker_data.empty: - print(f"โŒ Could not get price data for {symbol}") - return False - - current_price = float(ticker_data['close'].iloc[-1]) - print(f"โœ… Current {symbol} price: ${current_price:.2f}") - - # Calculate order size for $1 - usd_amount = 1.0 - crypto_amount = usd_amount / current_price - print(f"๐Ÿ’ฑ $1 USD = {crypto_amount:.6f} ETH") - - # Test buy signal execution - print(f"\n๐Ÿ›’ Testing BUY signal execution...") - buy_success = executor.execute_signal( - symbol=symbol, - action='BUY', - confidence=0.75, - current_price=current_price - ) - - if buy_success: - print("โœ… BUY signal executed successfully") - - # Check position - positions = executor.get_positions() - if symbol in positions: - position = positions[symbol] - print(f"๐Ÿ“ Position opened: {position.quantity:.6f} {symbol} @ ${position.entry_price:.2f}") - - # Test sell signal execution - print(f"\n๐Ÿ’ฐ Testing SELL signal execution...") - sell_success = executor.execute_signal( - symbol=symbol, - action='SELL', - confidence=0.80, - current_price=current_price * 1.001 # Simulate small price increase - ) - - if sell_success: - print("โœ… SELL signal executed successfully") - - # Check trade history - trades = executor.get_trade_history() - if trades: - last_trade = trades[-1] - print(f"๐Ÿ“Š Trade completed: P&L = ${last_trade.pnl:.4f}") - - return True - else: - print("โŒ SELL signal failed") - return False - else: - print("โŒ No position found after BUY signal") - return False - else: - print("โŒ BUY signal failed") - return False - - except Exception as e: - logger.error(f"Error testing order execution: {e}") - return False - -def test_dashboard_balance_integration(): - """Test dashboard balance integration""" - print("\n" + "="*60) - print("TESTING DASHBOARD BALANCE INTEGRATION") - print("="*60) - - try: - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - - # Create dashboard with trading executor - executor = TradingExecutor() - dashboard = TradingDashboard(trading_executor=executor) - - print(f"Dashboard starting balance: ${dashboard.starting_balance:.2f}") - - if dashboard.starting_balance > 0: - print("โœ… Dashboard successfully retrieved starting balance") - return True - else: - print("โš ๏ธ Dashboard using default balance (MEXC not connected)") - return False - - except Exception as e: - logger.error(f"Error testing dashboard integration: {e}") - return False - -def main(): - """Run all tests""" - print("๐Ÿš€ MEXC INTEGRATION TESTING") - print("Testing balance retrieval and $1 order execution") - - # Test 1: Balance retrieval - balance_test = test_mexc_balance() - - # Test 2: Order execution (only if balance test passes) - if balance_test: - order_test = test_mexc_order_execution() - else: - print("\nโญ๏ธ Skipping order execution test (balance test failed)") - order_test = False - - # Test 3: Dashboard integration - dashboard_test = test_dashboard_balance_integration() - - # Summary - print("\n" + "="*60) - print("TEST SUMMARY") - print("="*60) - print(f"Balance Retrieval: {'โœ… PASS' if balance_test else 'โŒ FAIL'}") - print(f"Order Execution: {'โœ… PASS' if order_test else 'โŒ FAIL'}") - print(f"Dashboard Integration: {'โœ… PASS' if dashboard_test else 'โŒ FAIL'}") - - if balance_test and order_test and dashboard_test: - print("\n๐ŸŽ‰ ALL TESTS PASSED - Ready for live $1 testing!") - return True - else: - print("\nโš ๏ธ Some tests failed - check configuration and API keys") - return False - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_mexc_data_integration.py b/tests/test_mexc_data_integration.py deleted file mode 100644 index 0519ecb..0000000 --- a/tests/test_mexc_data_integration.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/tests/test_mexc_new_keys.py b/tests/test_mexc_new_keys.py deleted file mode 100644 index 0ea1660..0000000 --- a/tests/test_mexc_new_keys.py +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for MEXC API with new credentials -""" - -import os -import sys -from dotenv import load_dotenv - -# Load environment variables -load_dotenv() - -def test_api_credentials(): - """Test MEXC API credentials step by step""" - print("="*60) - print("MEXC API CREDENTIALS TEST") - print("="*60) - - # Check environment variables - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - print(f"1. Environment Variables:") - print(f" API Key: {api_key[:5]}...{api_key[-5:] if api_key else 'None'}") - print(f" API Secret: {api_secret[:5]}...{api_secret[-5:] if api_secret else 'None'}") - print(f" API Key Length: {len(api_key) if api_key else 0}") - print(f" API Secret Length: {len(api_secret) if api_secret else 0}") - - if not api_key or not api_secret: - print("โŒ API credentials not found in environment") - return False - - # Test public API first - print(f"\n2. Testing Public API (no authentication):") - try: - from NN.exchanges.mexc_interface import MEXCInterface - api = MEXCInterface('dummy', 'dummy', test_mode=False) - - ticker = api.get_ticker('ETHUSDT') - if ticker: - print(f" โœ… Public API works: ETH/USDT = ${ticker.get('last', 'N/A')}") - else: - print(f" โŒ Public API failed") - return False - except Exception as e: - print(f" โŒ Public API error: {e}") - return False - - # Test private API with actual credentials - print(f"\n3. Testing Private API (with authentication):") - try: - api_auth = MEXCInterface(api_key, api_secret, test_mode=False) - - # Try to get account info - account_info = api_auth.get_account_info() - if account_info: - print(f" โœ… Private API works: Account info retrieved") - print(f" ๐Ÿ“Š Account Type: {account_info.get('accountType', 'N/A')}") - # Try to get USDT balance - usdt_balance = api_auth.get_balance('USDT') - print(f" ๐Ÿ’ฐ USDT Balance: {usdt_balance}") - return True - else: - print(f" โŒ Private API failed: Could not get account info") - return False - - except Exception as e: - print(f" โŒ Private API error: {e}") - return False - -def test_api_permissions(): - """Test specific API permissions""" - print(f"\n4. Testing API Permissions:") - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - try: - from NN.exchanges.mexc_interface import MEXCInterface - api = MEXCInterface(api_key, api_secret, test_mode=False) - - # Test spot trading permissions - print(" Testing spot trading permissions...") - - # Try to get open orders (requires spot trading permission) - try: - orders = api.get_open_orders('ETHUSDT') - print(" โœ… Spot trading permission: OK") - except Exception as e: - print(f" โŒ Spot trading permission: {e}") - return False - - return True - - except Exception as e: - print(f" โŒ Permission test error: {e}") - return False - -def main(): - """Main test function""" - success = test_api_credentials() - - if success: - test_api_permissions() - print(f"\nโœ… MEXC API SETUP COMPLETE") - print("The trading system should now work with live MEXC spot trading") - else: - print(f"\nโŒ MEXC API SETUP FAILED") - print("Possible issues:") - print("1. API key or secret incorrect") - print("2. API key not activated yet") - print("3. Insufficient permissions (need spot trading)") - print("4. IP address not whitelisted") - print("5. Account verification incomplete") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_mexc_order_debug.py b/tests/test_mexc_order_debug.py deleted file mode 100644 index 62d79af..0000000 --- a/tests/test_mexc_order_debug.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -MEXC Order Execution Debug Script - -This script tests MEXC order execution step by step to identify any issues -with the trading integration. -""" - -import os -import sys -import logging -import time -from datetime import datetime -from typing import Dict, Any -from dotenv import load_dotenv - -# Add paths for imports -sys.path.append(os.path.join(os.path.dirname(__file__), 'core')) -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) - -from core.trading_executor import TradingExecutor -from core.data_provider import DataProvider -from NN.exchanges.mexc_interface import MEXCInterface - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("mexc_order_debug.log"), - logging.StreamHandler() - ] -) -logger = logging.getLogger("mexc_debug") - -class MEXCOrderDebugger: - """Debug MEXC order execution step by step""" - - def __init__(self): - self.test_symbol = 'ETH/USDC' # ETH with USDC (supported by MEXC API) - self.test_quantity = 0.15 # $0.15 worth of ETH for testing (within our balance) - - # Load environment variables - load_dotenv() - - self.api_key = os.getenv('MEXC_API_KEY') - self.api_secret = os.getenv('MEXC_SECRET_KEY') - - def run_comprehensive_test(self): - """Run comprehensive MEXC order execution test""" - print("="*80) - print("MEXC ORDER EXECUTION DEBUG TEST") - print("="*80) - - # Step 1: Test environment variables - print("\n1. Testing Environment Variables...") - if not self.test_environment_variables(): - return False - - # Step 2: Test MEXC interface creation - print("\n2. Testing MEXC Interface Creation...") - mexc = self.test_mexc_interface_creation() - if not mexc: - return False - - # Step 3: Test connection - print("\n3. Testing MEXC Connection...") - if not self.test_mexc_connection(mexc): - return False - - # Step 4: Test account info - print("\n4. Testing Account Information...") - if not self.test_account_info(mexc): - return False - - # Step 5: Test ticker data - print("\n5. Testing Ticker Data...") - current_price = self.test_ticker_data(mexc) - if not current_price: - return False - - # Step 6: Test trading executor - print("\n6. Testing Trading Executor...") - executor = self.test_trading_executor_creation() - if not executor: - return False - - # Step 7: Test order placement (simulation) - print("\n7. Testing Order Placement...") - if not self.test_order_placement(executor, current_price): - return False - - # Step 8: Test order parameters - print("\n8. Testing Order Parameters...") - if not self.test_order_parameters(mexc, current_price): - return False - - print("\n" + "="*80) - print("โœ… ALL TESTS COMPLETED SUCCESSFULLY!") - print("MEXC order execution system appears to be working correctly.") - print("="*80) - - return True - - def test_environment_variables(self) -> bool: - """Test environment variables""" - try: - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key: - print("โŒ MEXC_API_KEY environment variable not set") - return False - - if not api_secret: - print("โŒ MEXC_SECRET_KEY environment variable not set") - return False - - print(f"โœ… MEXC_API_KEY: {api_key[:8]}...{api_key[-4:]} (length: {len(api_key)})") - print(f"โœ… MEXC_SECRET_KEY: {api_secret[:8]}...{api_secret[-4:]} (length: {len(api_secret)})") - - return True - - except Exception as e: - print(f"โŒ Error checking environment variables: {e}") - return False - - def test_mexc_interface_creation(self) -> MEXCInterface: - """Test MEXC interface creation""" - try: - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - mexc = MEXCInterface( - api_key=api_key, - api_secret=api_secret, - test_mode=True # Use testnet for safety - ) - - print(f"โœ… MEXC Interface created successfully") - print(f" - Test mode: {mexc.test_mode}") - print(f" - Base URL: {mexc.base_url}") - - return mexc - - except Exception as e: - print(f"โŒ Error creating MEXC interface: {e}") - return None - - def test_mexc_connection(self, mexc: MEXCInterface) -> bool: - """Test MEXC connection""" - try: - # Test ping - ping_result = mexc.ping() - print(f"โœ… MEXC Ping successful: {ping_result}") - - # Test server time - server_time = mexc.get_server_time() - print(f"โœ… MEXC Server time: {server_time}") - - # Test connection method - connected = mexc.connect() - print(f"โœ… MEXC Connection: {connected}") - - return True - - except Exception as e: - print(f"โŒ Error testing MEXC connection: {e}") - logger.error(f"MEXC connection error: {e}", exc_info=True) - return False - - def test_account_info(self, mexc: MEXCInterface) -> bool: - """Test account information retrieval""" - try: - account_info = mexc.get_account_info() - print(f"โœ… Account info retrieved successfully") - print(f" - Can trade: {account_info.get('canTrade', 'Unknown')}") - print(f" - Can withdraw: {account_info.get('canWithdraw', 'Unknown')}") - print(f" - Can deposit: {account_info.get('canDeposit', 'Unknown')}") - print(f" - Account type: {account_info.get('accountType', 'Unknown')}") - - # Test balance retrieval - balances = account_info.get('balances', []) - usdc_balance = 0 - usdt_balance = 0 - for balance in balances: - if balance.get('asset') == 'USDC': - usdc_balance = float(balance.get('free', 0)) - elif balance.get('asset') == 'USDT': - usdt_balance = float(balance.get('free', 0)) - - print(f" - USDC Balance: {usdc_balance}") - print(f" - USDT Balance: {usdt_balance}") - - if usdc_balance < self.test_quantity: - print(f"โš ๏ธ Warning: USDC balance ({usdc_balance}) is less than test amount ({self.test_quantity})") - if usdt_balance >= self.test_quantity: - print(f"๐Ÿ’ก Note: You have sufficient USDT ({usdt_balance}), but we need USDC for ETH/USDC trading") - - return True - - except Exception as e: - print(f"โŒ Error retrieving account info: {e}") - logger.error(f"Account info error: {e}", exc_info=True) - return False - - def test_ticker_data(self, mexc: MEXCInterface) -> float: - """Test ticker data retrieval""" - try: - ticker = mexc.get_ticker(self.test_symbol) - if not ticker: - print(f"โŒ Failed to get ticker for {self.test_symbol}") - return None - - current_price = ticker['last'] - print(f"โœ… Ticker data retrieved for {self.test_symbol}") - print(f" - Last price: ${current_price:.2f}") - print(f" - Bid: ${ticker.get('bid', 0):.2f}") - print(f" - Ask: ${ticker.get('ask', 0):.2f}") - print(f" - Volume: {ticker.get('volume', 0)}") - - return current_price - - except Exception as e: - print(f"โŒ Error retrieving ticker data: {e}") - logger.error(f"Ticker data error: {e}", exc_info=True) - return None - - def test_trading_executor_creation(self) -> TradingExecutor: - """Test trading executor creation""" - try: - executor = TradingExecutor() - print(f"โœ… Trading Executor created successfully") - print(f" - Trading enabled: {executor.trading_enabled}") - print(f" - Trading mode: {executor.trading_mode}") - print(f" - Simulation mode: {executor.simulation_mode}") - - return executor - - except Exception as e: - print(f"โŒ Error creating trading executor: {e}") - logger.error(f"Trading executor error: {e}", exc_info=True) - return None - - def test_order_placement(self, executor: TradingExecutor, current_price: float) -> bool: - """Test order placement through executor""" - try: - print(f"Testing BUY signal execution...") - - # Test BUY signal - buy_success = executor.execute_signal( - symbol=self.test_symbol, - action='BUY', - confidence=0.75, - current_price=current_price - ) - - print(f"โœ… BUY signal execution: {'SUCCESS' if buy_success else 'FAILED'}") - - if buy_success: - # Check positions - positions = executor.get_positions() - if self.test_symbol in positions: - position = positions[self.test_symbol] - print(f" - Position created: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}") - - # Test SELL signal - print(f"Testing SELL signal execution...") - sell_success = executor.execute_signal( - symbol=self.test_symbol, - action='SELL', - confidence=0.80, - current_price=current_price * 1.001 # Simulate small price increase - ) - - print(f"โœ… SELL signal execution: {'SUCCESS' if sell_success else 'FAILED'}") - - if sell_success: - # Check trade history - trades = executor.get_trade_history() - if trades: - last_trade = trades[-1] - print(f" - Trade P&L: ${last_trade.pnl:.4f}") - - return sell_success - else: - print("โŒ No position found after BUY signal") - return False - - return buy_success - - except Exception as e: - print(f"โŒ Error testing order placement: {e}") - logger.error(f"Order placement error: {e}", exc_info=True) - return False - - def test_order_parameters(self, mexc: MEXCInterface, current_price: float) -> bool: - """Test order parameters and validation""" - try: - print("Testing order parameter calculation...") - - # Calculate test order size - crypto_quantity = self.test_quantity / current_price - print(f" - USD amount: ${self.test_quantity}") - print(f" - Current price: ${current_price:.2f}") - print(f" - Crypto quantity: {crypto_quantity:.6f} ETH") - - # Test order parameters formatting - mexc_symbol = self.test_symbol.replace('/', '') - print(f" - MEXC symbol format: {mexc_symbol}") - - order_params = { - 'symbol': mexc_symbol, - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': str(crypto_quantity), - 'recvWindow': 5000 - } - - print(f" - Order parameters: {order_params}") - - # Test signature generation (without actually placing order) - print("Testing signature generation...") - test_params = order_params.copy() - test_params['timestamp'] = int(time.time() * 1000) - - try: - signature = mexc._generate_signature(test_params) - print(f"โœ… Signature generated successfully (length: {len(signature)})") - except Exception as e: - print(f"โŒ Signature generation failed: {e}") - return False - - print("โœ… Order parameters validation successful") - return True - - except Exception as e: - print(f"โŒ Error testing order parameters: {e}") - logger.error(f"Order parameters error: {e}", exc_info=True) - return False - -def main(): - """Main test function""" - try: - debugger = MEXCOrderDebugger() - success = debugger.run_comprehensive_test() - - if success: - print("\n๐ŸŽ‰ MEXC order execution system is working correctly!") - print("You can now safely execute live trades.") - else: - print("\n๐Ÿšจ MEXC order execution has issues that need to be resolved.") - print("Check the logs above for specific error details.") - - return success - - except Exception as e: - logger.error(f"Error in main test: {e}", exc_info=True) - print(f"\nโŒ Critical error during testing: {e}") - return False - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_mexc_order_sizes.py b/tests/test_mexc_order_sizes.py deleted file mode 100644 index ea0f4b7..0000000 --- a/tests/test_mexc_order_sizes.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Test MEXC Order Size Requirements - -This script tests different order sizes to identify minimum order requirements -and understand why order placement is failing. -""" - -import os -import sys -import logging -import time - -# Add paths for imports -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) - -from NN.exchanges.mexc_interface import MEXCInterface - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("mexc_order_test") - -def test_order_sizes(): - """Test different order sizes to find minimum requirements""" - print("="*60) - print("MEXC ORDER SIZE REQUIREMENTS TEST") - print("="*60) - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - print("โŒ Missing API credentials") - return False - - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True) - - # Get current ETH price - ticker = mexc.get_ticker('ETH/USDT') - if not ticker: - print("โŒ Failed to get ETH price") - return False - - current_price = ticker['last'] - print(f"Current ETH price: ${current_price:.2f}") - - # Test different USD amounts - test_amounts_usd = [0.1, 0.5, 1.0, 5.0, 10.0, 20.0] - - print(f"\nTesting different order sizes...") - print(f"{'USD Amount':<12} {'ETH Quantity':<15} {'Min ETH?':<10} {'Min USD?':<10}") - print("-" * 50) - - for usd_amount in test_amounts_usd: - eth_quantity = usd_amount / current_price - - # Check if quantity meets common minimum requirements - min_eth_ok = eth_quantity >= 0.001 # 0.001 ETH common minimum - min_usd_ok = usd_amount >= 5.0 # $5 common minimum - - print(f"${usd_amount:<11.2f} {eth_quantity:<15.6f} {'โœ…' if min_eth_ok else 'โŒ':<9} {'โœ…' if min_usd_ok else 'โŒ':<9}") - - # Test actual order parameter validation - print(f"\nTesting order parameter validation...") - - # Test small order (likely to fail) - small_usd = 1.0 - small_eth = small_usd / current_price - - print(f"\n1. Testing small order: ${small_usd} (${small_eth:.6f} ETH)") - success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', small_eth) - - # Test medium order (might work) - medium_usd = 10.0 - medium_eth = medium_usd / current_price - - print(f"\n2. Testing medium order: ${medium_usd} (${medium_eth:.6f} ETH)") - success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', medium_eth) - - # Test with rounded quantities - print(f"\n3. Testing with rounded quantities...") - - # Test 0.001 ETH (common minimum) - print(f" Testing 0.001 ETH (${0.001 * current_price:.2f})") - success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', 0.001) - - # Test 0.01 ETH - print(f" Testing 0.01 ETH (${0.01 * current_price:.2f})") - success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', 0.01) - - return True - -def test_order_validation(mexc: MEXCInterface, symbol: str, side: str, order_type: str, quantity: float) -> bool: - """Test order parameter validation without actually placing the order""" - try: - # Prepare order parameters - params = { - 'symbol': symbol, - 'side': side, - 'type': order_type, - 'quantity': str(quantity), - 'recvWindow': 5000, - 'timestamp': int(time.time() * 1000) - } - - # Generate signature - signature = mexc._generate_signature(params) - params['signature'] = signature - - print(f" Params: {params}") - - # Try to validate parameters by making the request but catching the specific error - headers = {'X-MEXC-APIKEY': mexc.api_key} - url = f"{mexc.base_url}/{mexc.api_version}/order" - - import requests - - # Make the request to see what specific error we get - response = requests.post(url, params=params, headers=headers, timeout=30) - - if response.status_code == 200: - print(" โœ… Order would be accepted (parameters valid)") - return True - else: - response_data = response.json() if response.headers.get('content-type', '').startswith('application/json') else {'msg': response.text} - error_code = response_data.get('code', 'Unknown') - error_msg = response_data.get('msg', 'Unknown error') - - print(f" โŒ Error {error_code}: {error_msg}") - - # Analyze specific error codes - if error_code == 400001: - print(" โ†’ Invalid parameter format") - elif error_code == 700002: - print(" โ†’ Invalid signature") - elif error_code == 70016: - print(" โ†’ Order size too small") - elif error_code == 70015: - print(" โ†’ Insufficient balance") - elif 'LOT_SIZE' in error_msg: - print(" โ†’ Lot size violation (quantity precision/minimum)") - elif 'MIN_NOTIONAL' in error_msg: - print(" โ†’ Minimum notional value not met") - - return False - - except Exception as e: - print(f" โŒ Exception: {e}") - return False - -def get_symbol_info(): - """Get symbol trading rules and limits""" - print("\nGetting symbol trading rules...") - - try: - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True) - - # Try to get exchange info - import requests - - url = f"{mexc.base_url}/{mexc.api_version}/exchangeInfo" - response = requests.get(url, timeout=30) - - if response.status_code == 200: - exchange_info = response.json() - - # Find ETHUSDT symbol info - for symbol_info in exchange_info.get('symbols', []): - if symbol_info.get('symbol') == 'ETHUSDT': - print(f"Found ETHUSDT trading rules:") - print(f" Status: {symbol_info.get('status')}") - print(f" Base asset: {symbol_info.get('baseAsset')}") - print(f" Quote asset: {symbol_info.get('quoteAsset')}") - - # Check filters - for filter_info in symbol_info.get('filters', []): - filter_type = filter_info.get('filterType') - if filter_type == 'LOT_SIZE': - print(f" Lot Size Filter:") - print(f" Min Qty: {filter_info.get('minQty')}") - print(f" Max Qty: {filter_info.get('maxQty')}") - print(f" Step Size: {filter_info.get('stepSize')}") - elif filter_type == 'MIN_NOTIONAL': - print(f" Min Notional Filter:") - print(f" Min Notional: {filter_info.get('minNotional')}") - elif filter_type == 'PRICE_FILTER': - print(f" Price Filter:") - print(f" Min Price: {filter_info.get('minPrice')}") - print(f" Max Price: {filter_info.get('maxPrice')}") - print(f" Tick Size: {filter_info.get('tickSize')}") - - break - else: - print("โŒ ETHUSDT symbol not found in exchange info") - else: - print(f"โŒ Failed to get exchange info: {response.status_code}") - - except Exception as e: - print(f"โŒ Error getting symbol info: {e}") - -if __name__ == "__main__": - get_symbol_info() - test_order_sizes() \ No newline at end of file diff --git a/tests/test_mexc_public_api.py b/tests/test_mexc_public_api.py deleted file mode 100644 index 09abe86..0000000 --- a/tests/test_mexc_public_api.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for MEXC public API endpoints -""" - -import sys -import os -import logging - -# Add project root to path -sys.path.insert(0, os.path.abspath('.')) - -from NN.exchanges.mexc_interface import MEXCInterface - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_mexc_public_api(): - """Test MEXC public API endpoints""" - print("="*60) - print("TESTING MEXC PUBLIC API") - print("="*60) - - try: - # Initialize MEXC interface without API keys (public access only) - mexc = MEXCInterface() - - print("\n1. Testing server connectivity...") - try: - # Test ping - ping_result = mexc.ping() - print(f"โœ… Ping successful: {ping_result}") - except Exception as e: - print(f"โŒ Ping failed: {e}") - - print("\n2. Testing server time...") - try: - # Test server time - time_result = mexc.get_server_time() - print(f"โœ… Server time: {time_result}") - except Exception as e: - print(f"โŒ Server time failed: {e}") - - print("\n3. Testing ticker data...") - symbols_to_test = ['BTC/USDT', 'ETH/USDT'] - - for symbol in symbols_to_test: - try: - ticker = mexc.get_ticker(symbol) - if ticker: - print(f"โœ… {symbol}: ${ticker['last']:.2f} (bid: ${ticker['bid']:.2f}, ask: ${ticker['ask']:.2f})") - else: - print(f"โŒ {symbol}: No data returned") - except Exception as e: - print(f"โŒ {symbol}: Error - {e}") - - print("\n" + "="*60) - print("PUBLIC API TEST COMPLETED") - print("="*60) - - except Exception as e: - print(f"โŒ Error initializing MEXC interface: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_mexc_public_api() \ No newline at end of file diff --git a/tests/test_mexc_signature.py b/tests/test_mexc_signature.py deleted file mode 100644 index 701f248..0000000 --- a/tests/test_mexc_signature.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -Test MEXC Signature Generation - -This script tests the MEXC signature generation to ensure it's correct -according to the MEXC API documentation. -""" - -import os -import sys -import hashlib -import hmac -from urllib.parse import urlencode -import time -import requests - -# Add paths for imports -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) - -from NN.exchanges.mexc_interface import MEXCInterface - -def test_signature_generation(): - """Test MEXC signature generation with known examples""" - print("="*60) - print("MEXC SIGNATURE GENERATION TEST") - print("="*60) - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - print("โŒ Missing API credentials") - return False - - print(f"API Key: {api_key[:8]}...{api_key[-4:]}") - print(f"API Secret: {api_secret[:8]}...{api_secret[-4:]}") - - # Test 1: Simple signature generation - print("\n1. Testing basic signature generation...") - - test_params = { - 'symbol': 'ETHUSDT', - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': '0.001', - 'timestamp': 1640995200000, - 'recvWindow': 5000 - } - - # Generate signature manually - sorted_params = sorted(test_params.items()) - query_string = urlencode(sorted_params) - expected_signature = hmac.new( - api_secret.encode('utf-8'), - query_string.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - - print(f"Query string: {query_string}") - print(f"Expected signature: {expected_signature}") - - # Generate signature using MEXC interface - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True) - actual_signature = mexc._generate_signature(test_params) - - print(f"Actual signature: {actual_signature}") - - if expected_signature == actual_signature: - print("โœ… Signature generation matches expected") - else: - print("โŒ Signature generation mismatch") - return False - - # Test 2: Real order parameters - print("\n2. Testing with real order parameters...") - - current_timestamp = int(time.time() * 1000) - real_params = { - 'symbol': 'ETHUSDT', - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': '0.001', - 'timestamp': current_timestamp, - 'recvWindow': 5000 - } - - real_signature = mexc._generate_signature(real_params) - sorted_real_params = sorted(real_params.items()) - real_query_string = urlencode(sorted_real_params) - - print(f"Real timestamp: {current_timestamp}") - print(f"Real query string: {real_query_string}") - print(f"Real signature: {real_signature}") - - # Test 3: Verify parameter ordering - print("\n3. Testing parameter ordering sensitivity...") - - # Test with parameters in different order - unordered_params = { - 'timestamp': current_timestamp, - 'symbol': 'ETHUSDT', - 'recvWindow': 5000, - 'type': 'MARKET', - 'side': 'BUY', - 'quantity': '0.001' - } - - unordered_signature = mexc._generate_signature(unordered_params) - - if real_signature == unordered_signature: - print("โœ… Parameter ordering handled correctly") - else: - print("โŒ Parameter ordering issue") - return False - - # Test 4: Check for common issues - print("\n4. Checking for common signature issues...") - - # Check if any parameters need special encoding - special_params = { - 'symbol': 'ETHUSDT', - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': '0.0028417810009889397', # Full precision from error log - 'timestamp': current_timestamp, - 'recvWindow': 5000 - } - - special_signature = mexc._generate_signature(special_params) - special_sorted = sorted(special_params.items()) - special_query = urlencode(special_sorted) - - print(f"Special quantity: {special_params['quantity']}") - print(f"Special query: {special_query}") - print(f"Special signature: {special_signature}") - - # Test 5: Compare with error log signature - print("\n5. Comparing with error log...") - - # From the error log, we have this signature: - error_log_signature = "2a52436039e24b593ab0ab20ac1a67e2323654dc14190ee2c2cde341930d27d4" - error_timestamp = 1748349875981 - - error_params = { - 'symbol': 'ETHUSDT', - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': '0.0028417810009889397', - 'recvWindow': 5000, - 'timestamp': error_timestamp - } - - recreated_signature = mexc._generate_signature(error_params) - - print(f"Error log signature: {error_log_signature}") - print(f"Recreated signature: {recreated_signature}") - - if error_log_signature == recreated_signature: - print("โœ… Signature recreation matches error log") - else: - print("โŒ Signature recreation doesn't match - potential algorithm issue") - - # Debug the query string - debug_sorted = sorted(error_params.items()) - debug_query = urlencode(debug_sorted) - print(f"Debug query string: {debug_query}") - - # Try manual HMAC - manual_sig = hmac.new( - api_secret.encode('utf-8'), - debug_query.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - print(f"Manual signature: {manual_sig}") - - print("\n" + "="*60) - print("SIGNATURE TEST COMPLETED") - print("="*60) - - return True - -def test_mexc_api_call(): - """Test a simple authenticated API call to verify signature works""" - print("\n6. Testing authenticated API call...") - - try: - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True) - - # Test account info (this should work if signature is correct) - account_info = mexc.get_account_info() - print("โœ… Account info call successful - signature is working") - print(f" Account type: {account_info.get('accountType', 'Unknown')}") - print(f" Can trade: {account_info.get('canTrade', 'Unknown')}") - - return True - - except Exception as e: - print(f"โŒ Account info call failed: {e}") - return False - -# Test exact signature generation for MEXC order placement -def test_mexc_order_signature(): - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - print("โŒ MEXC API keys not found") - return - - print("=== MEXC ORDER SIGNATURE DEBUG ===") - print(f"API Key: {api_key[:8]}...{api_key[-4:]}") - print(f"Secret Key: {api_secret[:8]}...{api_secret[-4:]}") - print() - - # Get server time first - try: - time_resp = requests.get('https://api.mexc.com/api/v3/time') - server_time = time_resp.json()['serverTime'] - print(f"โœ… Server time: {server_time}") - except Exception as e: - print(f"โŒ Failed to get server time: {e}") - server_time = int(time.time() * 1000) - print(f"Using local time: {server_time}") - - # Test order parameters (from MEXC documentation example) - params = { - 'symbol': 'MXUSDT', # Changed to API-supported symbol - 'side': 'BUY', - 'type': 'MARKET', - 'quantity': '1', # Small test quantity (1 MX token) - 'timestamp': server_time - } - - print("\n=== Testing Different Signature Methods ===") - - # Method 1: Sorted parameters with & separator (current approach) - print("\n1. Current approach (sorted with &):") - sorted_params = sorted(params.items()) - query_string1 = '&'.join([f"{key}={value}" for key, value in sorted_params]) - signature1 = hmac.new( - api_secret.encode('utf-8'), - query_string1.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - print(f"Query: {query_string1}") - print(f"Signature: {signature1}") - - # Method 2: URL encoded (like account info that works) - print("\n2. URL encoded approach:") - sorted_params = sorted(params.items()) - query_string2 = urlencode(sorted_params) - signature2 = hmac.new( - api_secret.encode('utf-8'), - query_string2.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - print(f"Query: {query_string2}") - print(f"Signature: {signature2}") - - # Method 3: MEXC documentation example format - print("\n3. MEXC docs example format:") - # From MEXC docs: symbol=BTCUSDT&side=BUY&type=LIMIT&quantity=1&price=11&recvWindow=5000×tamp=1644489390087 - query_string3 = f"symbol={params['symbol']}&side={params['side']}&type={params['type']}&quantity={params['quantity']}×tamp={params['timestamp']}" - signature3 = hmac.new( - api_secret.encode('utf-8'), - query_string3.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - print(f"Query: {query_string3}") - print(f"Signature: {signature3}") - - # Test all methods by making actual requests - print("\n=== Testing Actual Requests ===") - - methods = [ - ("Current approach", signature1, params), - ("URL encoded", signature2, params), - ("MEXC docs format", signature3, params) - ] - - for method_name, signature, test_params in methods: - print(f"\n{method_name}:") - test_params_copy = test_params.copy() - test_params_copy['signature'] = signature - - headers = {'X-MEXC-APIKEY': api_key} - - try: - response = requests.post( - 'https://api.mexc.com/api/v3/order', - params=test_params_copy, - headers=headers, - timeout=10 - ) - print(f"Status: {response.status_code}") - print(f"Response: {response.text[:200]}...") - - if response.status_code == 200: - print("โœ… SUCCESS!") - break - elif "Signature for this request is not valid" in response.text: - print("โŒ Invalid signature") - else: - print(f"โŒ Other error: {response.status_code}") - - except Exception as e: - print(f"โŒ Request failed: {e}") - -if __name__ == "__main__": - success = test_signature_generation() - - if success: - success = test_mexc_api_call() - - if success: - test_mexc_order_signature() - print("\n๐ŸŽ‰ All signature tests passed!") - else: - print("\n๐Ÿšจ Signature tests failed - check the output above") \ No newline at end of file diff --git a/tests/test_mexc_timestamp_debug.py b/tests/test_mexc_timestamp_debug.py deleted file mode 100644 index c906f0f..0000000 --- a/tests/test_mexc_timestamp_debug.py +++ /dev/null @@ -1,185 +0,0 @@ -""" -MEXC Timestamp and Signature Debug - -This script tests different timestamp and recvWindow combinations to fix the signature validation. -""" - -import os -import sys -import time -import hashlib -import hmac -from urllib.parse import urlencode -import requests - -# Add paths for imports -sys.path.append(os.path.join(os.path.dirname(__file__), 'NN')) - -def test_mexc_timestamp_debug(): - """Test different timestamp strategies""" - print("="*60) - print("MEXC TIMESTAMP AND SIGNATURE DEBUG") - print("="*60) - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - if not api_key or not api_secret: - print("โŒ Missing API credentials") - return False - - base_url = "https://api.mexc.com" - api_version = "api/v3" - - # Test 1: Get server time directly - print("1. Getting server time...") - - try: - response = requests.get(f"{base_url}/{api_version}/time", timeout=10) - if response.status_code == 200: - server_time_data = response.json() - server_time = server_time_data['serverTime'] - local_time = int(time.time() * 1000) - time_diff = server_time - local_time - - print(f" Server time: {server_time}") - print(f" Local time: {local_time}") - print(f" Difference: {time_diff}ms") - - else: - print(f" โŒ Failed to get server time: {response.status_code}") - return False - except Exception as e: - print(f" โŒ Error getting server time: {e}") - return False - - # Test 2: Try different timestamp strategies - strategies = [ - ("Server time exactly", server_time), - ("Server time - 500ms", server_time - 500), - ("Server time - 1000ms", server_time - 1000), - ("Server time - 2000ms", server_time - 2000), - ("Local time", local_time), - ("Local time - 1000ms", local_time - 1000), - ] - - # Test with different recvWindow values - recv_windows = [5000, 10000, 30000, 60000] - - print(f"\n2. Testing different timestamp strategies and recvWindow values...") - - for strategy_name, timestamp in strategies: - print(f"\n Strategy: {strategy_name} (timestamp: {timestamp})") - - for recv_window in recv_windows: - print(f" Testing recvWindow: {recv_window}ms") - - # Test account info request - params = { - 'timestamp': timestamp, - 'recvWindow': recv_window - } - - # Generate signature - sorted_params = sorted(params.items()) - query_string = urlencode(sorted_params) - signature = hmac.new( - api_secret.encode('utf-8'), - query_string.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - - params['signature'] = signature - - # Make request - headers = {'X-MEXC-APIKEY': api_key} - url = f"{base_url}/{api_version}/account" - - try: - response = requests.get(url, params=params, headers=headers, timeout=10) - - if response.status_code == 200: - print(f" โœ… SUCCESS") - account_data = response.json() - print(f" Account type: {account_data.get('accountType', 'Unknown')}") - return True # Found working combination - else: - error_data = response.json() if 'application/json' in response.headers.get('content-type', '') else {'msg': response.text} - error_code = error_data.get('code', 'Unknown') - error_msg = error_data.get('msg', 'Unknown') - print(f" โŒ Error {error_code}: {error_msg}") - - except Exception as e: - print(f" โŒ Exception: {e}") - - print(f"\nโŒ No working timestamp/recvWindow combination found") - return False - -def test_minimal_signature(): - """Test with minimal parameters to isolate signature issues""" - print(f"\n3. Testing minimal signature generation...") - - api_key = os.getenv('MEXC_API_KEY') - api_secret = os.getenv('MEXC_SECRET_KEY') - - base_url = "https://api.mexc.com" - api_version = "api/v3" - - # Get fresh server time - try: - response = requests.get(f"{base_url}/{api_version}/time", timeout=10) - server_time = response.json()['serverTime'] - print(f" Fresh server time: {server_time}") - except Exception as e: - print(f" โŒ Failed to get server time: {e}") - return False - - # Test with absolute minimal parameters - minimal_params = { - 'timestamp': server_time - } - - # Generate signature with minimal params - sorted_params = sorted(minimal_params.items()) - query_string = urlencode(sorted_params) - signature = hmac.new( - api_secret.encode('utf-8'), - query_string.encode('utf-8'), - hashlib.sha256 - ).hexdigest() - - minimal_params['signature'] = signature - - print(f" Minimal params: {minimal_params}") - print(f" Query string: {query_string}") - print(f" Signature: {signature}") - - # Test account request with minimal params - headers = {'X-MEXC-APIKEY': api_key} - url = f"{base_url}/{api_version}/account" - - try: - response = requests.get(url, params=minimal_params, headers=headers, timeout=10) - - if response.status_code == 200: - print(f" โœ… Minimal signature works!") - return True - else: - error_data = response.json() if 'application/json' in response.headers.get('content-type', '') else {'msg': response.text} - print(f" โŒ Minimal signature failed: {error_data.get('code', 'Unknown')} - {error_data.get('msg', 'Unknown')}") - return False - - except Exception as e: - print(f" โŒ Exception with minimal signature: {e}") - return False - -if __name__ == "__main__": - success = test_mexc_timestamp_debug() - - if not success: - success = test_minimal_signature() - - if success: - print(f"\n๐ŸŽ‰ Found working MEXC configuration!") - else: - print(f"\n๐Ÿšจ MEXC signature/timestamp issue persists") \ No newline at end of file diff --git a/tests/test_mexc_trading_integration.py b/tests/test_mexc_trading_integration.py deleted file mode 100644 index 6e99196..0000000 --- a/tests/test_mexc_trading_integration.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -Test MEXC Trading Integration - -This script tests the integration between the enhanced orchestrator and MEXC trading executor. -It verifies that trading signals can be executed through the MEXC API with proper risk management. -""" - -import asyncio -import logging -import os -import sys -import time -from datetime import datetime - -# Add core directory to path -sys.path.append(os.path.join(os.path.dirname(__file__), 'core')) - -from core.trading_executor import TradingExecutor, Position, TradeRecord -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.data_provider import DataProvider -from core.config import get_config - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("test_mexc_trading.log"), - logging.StreamHandler() - ] -) -logger = logging.getLogger("mexc_trading_test") - -class TradingIntegrationTest: - """Test class for MEXC trading integration""" - - def __init__(self): - """Initialize the test environment""" - self.config = get_config() - self.data_provider = DataProvider() - self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) - self.trading_executor = TradingExecutor() - - # Test configuration - self.test_symbol = 'ETH/USDT' - self.test_confidence = 0.75 - - def test_trading_executor_initialization(self): - """Test that the trading executor initializes correctly""" - logger.info("Testing trading executor initialization...") - - try: - # Check configuration - assert self.trading_executor.mexc_config is not None - logger.info("โœ… MEXC configuration loaded") - - # Check dry run mode - assert self.trading_executor.dry_run == True - logger.info("โœ… Dry run mode enabled for safety") - - # Check position limits - max_position_value = self.trading_executor.mexc_config.get('max_position_value_usd', 1.0) - assert max_position_value == 1.0 - logger.info(f"โœ… Max position value set to ${max_position_value}") - - # Check safety features - assert self.trading_executor.mexc_config.get('emergency_stop', False) == False - logger.info("โœ… Emergency stop not active") - - return True - - except Exception as e: - logger.error(f"โŒ Trading executor initialization test failed: {e}") - return False - - def test_exchange_connection(self): - """Test connection to MEXC exchange""" - logger.info("Testing MEXC exchange connection...") - - try: - # Test ticker retrieval - ticker = self.trading_executor.exchange.get_ticker(self.test_symbol) - - if ticker: - logger.info(f"โœ… Successfully retrieved ticker for {self.test_symbol}") - logger.info(f" Current price: ${ticker['last']:.2f}") - logger.info(f" Bid: ${ticker['bid']:.2f}, Ask: ${ticker['ask']:.2f}") - return True - else: - logger.error(f"โŒ Failed to retrieve ticker for {self.test_symbol}") - return False - - except Exception as e: - logger.error(f"โŒ Exchange connection test failed: {e}") - return False - - def test_position_size_calculation(self): - """Test position size calculation with different confidence levels""" - logger.info("Testing position size calculation...") - - try: - test_price = 2500.0 - test_cases = [ - (0.5, "Medium confidence"), - (0.75, "High confidence"), - (0.9, "Very high confidence"), - (0.3, "Low confidence") - ] - - for confidence, description in test_cases: - position_value = self.trading_executor._calculate_position_size(confidence, test_price) - quantity = position_value / test_price - - logger.info(f" {description} ({confidence:.1f}): ${position_value:.2f} = {quantity:.6f} ETH") - - # Verify position value is within limits - max_value = self.trading_executor.mexc_config.get('max_position_value_usd', 1.0) - min_value = self.trading_executor.mexc_config.get('min_position_value_usd', 0.1) - - assert min_value <= position_value <= max_value - - logger.info("โœ… Position size calculation working correctly") - return True - - except Exception as e: - logger.error(f"โŒ Position size calculation test failed: {e}") - return False - - def test_dry_run_trading(self): - """Test dry run trading execution""" - logger.info("Testing dry run trading execution...") - - try: - # Get current price - ticker = self.trading_executor.exchange.get_ticker(self.test_symbol) - if not ticker: - logger.error("Cannot get current price for testing") - return False - - current_price = ticker['last'] - - # Test BUY signal - logger.info(f"Testing BUY signal for {self.test_symbol} at ${current_price:.2f}") - buy_success = self.trading_executor.execute_signal( - symbol=self.test_symbol, - action='BUY', - confidence=self.test_confidence, - current_price=current_price - ) - - if buy_success: - logger.info("โœ… BUY signal executed successfully in dry run mode") - - # Check position was created - positions = self.trading_executor.get_positions() - assert self.test_symbol in positions - position = positions[self.test_symbol] - logger.info(f" Position created: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}") - else: - logger.error("โŒ BUY signal execution failed") - return False - - # Wait a moment - time.sleep(1) - - # Test SELL signal - logger.info(f"Testing SELL signal for {self.test_symbol}") - sell_success = self.trading_executor.execute_signal( - symbol=self.test_symbol, - action='SELL', - confidence=self.test_confidence, - current_price=current_price * 1.01 # Simulate 1% price increase - ) - - if sell_success: - logger.info("โœ… SELL signal executed successfully in dry run mode") - - # Check position was closed - positions = self.trading_executor.get_positions() - assert self.test_symbol not in positions - - # Check trade history - trade_history = self.trading_executor.get_trade_history() - assert len(trade_history) > 0 - - last_trade = trade_history[-1] - logger.info(f" Trade completed: P&L ${last_trade.pnl:.2f}") - else: - logger.error("โŒ SELL signal execution failed") - return False - - logger.info("โœ… Dry run trading test completed successfully") - return True - - except Exception as e: - logger.error(f"โŒ Dry run trading test failed: {e}") - return False - - def test_safety_conditions(self): - """Test safety condition checks""" - logger.info("Testing safety condition checks...") - - try: - # Test symbol allowlist - disallowed_symbol = 'DOGE/USDT' - result = self.trading_executor._check_safety_conditions(disallowed_symbol, 'BUY') - if disallowed_symbol not in self.trading_executor.mexc_config.get('allowed_symbols', []): - assert result == False - logger.info("โœ… Symbol allowlist check working") - - # Test trade interval - # First trade should succeed - current_price = 2500.0 - self.trading_executor.execute_signal(self.test_symbol, 'BUY', 0.7, current_price) - - # Immediate second trade should fail due to interval - result = self.trading_executor._check_safety_conditions(self.test_symbol, 'BUY') - # Note: This might pass if interval is very short, which is fine for testing - - logger.info("โœ… Safety condition checks working") - return True - - except Exception as e: - logger.error(f"โŒ Safety condition test failed: {e}") - return False - - def test_daily_statistics(self): - """Test daily statistics tracking""" - logger.info("Testing daily statistics tracking...") - - try: - stats = self.trading_executor.get_daily_stats() - - required_keys = ['daily_trades', 'daily_loss', 'total_pnl', 'winning_trades', - 'losing_trades', 'win_rate', 'positions_count'] - - for key in required_keys: - assert key in stats - - logger.info("โœ… Daily statistics structure correct") - logger.info(f" Daily trades: {stats['daily_trades']}") - logger.info(f" Total P&L: ${stats['total_pnl']:.2f}") - logger.info(f" Win rate: {stats['win_rate']:.1%}") - - return True - - except Exception as e: - logger.error(f"โŒ Daily statistics test failed: {e}") - return False - - async def test_orchestrator_integration(self): - """Test integration with enhanced orchestrator""" - logger.info("Testing orchestrator integration...") - - try: - # Test that orchestrator can make decisions - decisions = await self.orchestrator.make_coordinated_decisions() - - logger.info(f"โœ… Orchestrator made decisions for {len(decisions)} symbols") - - for symbol, decision in decisions.items(): - if decision: - logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})") - - # Test executing the decision through trading executor - if decision.action != 'HOLD': - success = self.trading_executor.execute_signal( - symbol=symbol, - action=decision.action, - confidence=decision.confidence, - current_price=decision.price - ) - - if success: - logger.info(f" โœ… Successfully executed {decision.action} for {symbol}") - else: - logger.info(f" โš ๏ธ Trade execution blocked by safety conditions for {symbol}") - - return True - - except Exception as e: - logger.error(f"โŒ Orchestrator integration test failed: {e}") - return False - - def test_emergency_stop(self): - """Test emergency stop functionality""" - logger.info("Testing emergency stop functionality...") - - try: - # Create a test position first - current_price = 2500.0 - self.trading_executor.execute_signal(self.test_symbol, 'BUY', 0.7, current_price) - - # Verify position exists - positions_before = self.trading_executor.get_positions() - logger.info(f" Positions before emergency stop: {len(positions_before)}") - - # Trigger emergency stop - self.trading_executor.emergency_stop() - - # Verify trading is disabled - assert self.trading_executor.trading_enabled == False - logger.info("โœ… Trading disabled after emergency stop") - - # In dry run mode, positions should still be closed - positions_after = self.trading_executor.get_positions() - logger.info(f" Positions after emergency stop: {len(positions_after)}") - - return True - - except Exception as e: - logger.error(f"โŒ Emergency stop test failed: {e}") - return False - - async def run_all_tests(self): - """Run all integration tests""" - logger.info("๐Ÿš€ Starting MEXC Trading Integration Tests") - logger.info("=" * 60) - - tests = [ - ("Trading Executor Initialization", self.test_trading_executor_initialization), - ("Exchange Connection", self.test_exchange_connection), - ("Position Size Calculation", self.test_position_size_calculation), - ("Dry Run Trading", self.test_dry_run_trading), - ("Safety Conditions", self.test_safety_conditions), - ("Daily Statistics", self.test_daily_statistics), - ("Orchestrator Integration", self.test_orchestrator_integration), - ("Emergency Stop", self.test_emergency_stop), - ] - - passed = 0 - failed = 0 - - for test_name, test_func in tests: - logger.info(f"\n๐Ÿ“‹ Running test: {test_name}") - logger.info("-" * 40) - - try: - if asyncio.iscoroutinefunction(test_func): - result = await test_func() - else: - result = test_func() - - if result: - passed += 1 - logger.info(f"โœ… {test_name} PASSED") - else: - failed += 1 - logger.error(f"โŒ {test_name} FAILED") - - except Exception as e: - failed += 1 - logger.error(f"โŒ {test_name} FAILED with exception: {e}") - - logger.info("\n" + "=" * 60) - logger.info("๐Ÿ Test Results Summary") - logger.info(f"โœ… Passed: {passed}") - logger.info(f"โŒ Failed: {failed}") - logger.info(f"๐Ÿ“Š Success Rate: {passed/(passed+failed)*100:.1f}%") - - if failed == 0: - logger.info("๐ŸŽ‰ All tests passed! MEXC trading integration is ready.") - else: - logger.warning(f"โš ๏ธ {failed} test(s) failed. Please review and fix issues before live trading.") - - return failed == 0 - -async def main(): - """Main test function""" - test_runner = TradingIntegrationTest() - success = await test_runner.run_all_tests() - - if success: - logger.info("\n๐Ÿ”ง Next Steps:") - logger.info("1. Set up your MEXC API keys in .env file") - logger.info("2. Update config.yaml to enable trading (mexc_trading.enabled: true)") - logger.info("3. Consider disabling dry_run_mode for live trading") - logger.info("4. Start with small position sizes for initial live testing") - logger.info("5. Monitor the system closely during initial live trading") - - return success - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_minimal_dashboard.py b/tests/test_minimal_dashboard.py deleted file mode 100644 index 45c561b..0000000 --- a/tests/test_minimal_dashboard.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -Minimal Dashboard Test - Debug startup issues -""" - -import logging -import sys -import traceback - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_imports(): - """Test all required imports""" - try: - logger.info("Testing imports...") - - # Core imports - from core.config import get_config - logger.info("โœ“ core.config imported") - - from core.data_provider import DataProvider - logger.info("โœ“ core.data_provider imported") - - # Dashboard imports - import dash - from dash import dcc, html - import plotly.graph_objects as go - logger.info("โœ“ Dash imports successful") - - # Try to import the dashboard - from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard - logger.info("โœ“ RealTimeScalpingDashboard imported") - - return True - - except Exception as e: - logger.error(f"Import error: {e}") - traceback.print_exc() - return False - -def test_dashboard_creation(): - """Test dashboard creation""" - try: - logger.info("Testing dashboard creation...") - - from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard - from core.data_provider import DataProvider - - # Create data provider - data_provider = DataProvider() - logger.info("โœ“ DataProvider created") - - # Create dashboard - dashboard = RealTimeScalpingDashboard(data_provider=data_provider) - logger.info("โœ“ Dashboard created successfully") - - return dashboard - - except Exception as e: - logger.error(f"Dashboard creation error: {e}") - traceback.print_exc() - return None - -def test_dashboard_run(): - """Test dashboard run""" - try: - logger.info("Testing dashboard run...") - - dashboard = test_dashboard_creation() - if not dashboard: - return False - - # Try to run dashboard - logger.info("Starting dashboard on port 8052...") - dashboard.run(host='127.0.0.1', port=8052, debug=True) - - return True - - except Exception as e: - logger.error(f"Dashboard run error: {e}") - traceback.print_exc() - return False - -def main(): - """Main test function""" - logger.info("=== MINIMAL DASHBOARD TEST ===") - - # Test 1: Imports - if not test_imports(): - logger.error("Import test failed!") - sys.exit(1) - - # Test 2: Dashboard creation - dashboard = test_dashboard_creation() - if not dashboard: - logger.error("Dashboard creation test failed!") - sys.exit(1) - - # Test 3: Dashboard run - logger.info("All tests passed! Starting dashboard...") - test_dashboard_run() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_minimal_trading.py b/tests/test_minimal_trading.py deleted file mode 100644 index 6119dc4..0000000 --- a/tests/test_minimal_trading.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -""" -Minimal Trading Test -Test basic trading functionality with simplified decision logic -""" - -import logging -import asyncio -from datetime import datetime -import pandas as pd -import numpy as np - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -async def test_minimal_trading(): - """Test minimal trading with lowered thresholds""" - logger.info("=== MINIMAL TRADING TEST ===") - - try: - from core.config import get_config - from core.data_provider import DataProvider - from core.trading_executor import TradingExecutor - - # Initialize with minimal components - config = get_config() - data_provider = DataProvider() - trading_executor = TradingExecutor() - - logger.info("โœ… Basic components initialized") - - # Test data availability - symbol = 'ETH/USDT' - data = data_provider.get_historical_data(symbol, '1m', limit=20) - - if data is None or data.empty: - logger.error("โŒ No data available for minimal test") - return - - current_price = float(data['close'].iloc[-1]) - logger.info(f"โœ… Current {symbol} price: ${current_price:.2f}") - - # Generate simple trading signal - price_change = data['close'].pct_change().iloc[-5:].mean() - - # Simple momentum signal - if price_change > 0.001: # 0.1% positive momentum - action = 'BUY' - confidence = 0.6 # Above 35% threshold - reason = f"Positive momentum: {price_change:.1%}" - elif price_change < -0.001: # 0.1% negative momentum - action = 'SELL' - confidence = 0.6 # Above 35% threshold - reason = f"Negative momentum: {price_change:.1%}" - else: - action = 'HOLD' - confidence = 0.3 - reason = "Neutral momentum" - - logger.info(f"๐Ÿ“ˆ Signal: {action} with {confidence:.1%} confidence - {reason}") - - # Test if we would execute this trade - if confidence > 0.35: # Our new threshold - logger.info("โœ… Signal WOULD trigger trade execution") - - # Simulate position sizing - position_size = 0.01 # 0.01 ETH - estimated_value = position_size * current_price - - logger.info(f"๐Ÿ“Š Would trade {position_size} ETH (~${estimated_value:.2f})") - - # Test trading executor (simulation mode) - if hasattr(trading_executor, 'simulation_mode'): - trading_executor.simulation_mode = True - - logger.info("๐ŸŽฏ Trading signal meets threshold - system operational") - - else: - logger.warning(f"โŒ Signal below threshold ({confidence:.1%} < 35%)") - - # Test multiple timeframes - logger.info("\n=== MULTI-TIMEFRAME TEST ===") - timeframes = ['1m', '5m', '1h'] - signals = [] - - for tf in timeframes: - try: - tf_data = data_provider.get_historical_data(symbol, tf, limit=10) - if tf_data is not None and not tf_data.empty: - tf_change = tf_data['close'].pct_change().iloc[-3:].mean() - tf_confidence = min(0.8, abs(tf_change) * 100) - - signals.append({ - 'timeframe': tf, - 'change': tf_change, - 'confidence': tf_confidence - }) - - logger.info(f" {tf}: {tf_change:.2%} change, {tf_confidence:.1%} confidence") - except Exception as e: - logger.warning(f" {tf}: Error - {e}") - - # Combined signal - if signals: - avg_confidence = np.mean([s['confidence'] for s in signals]) - logger.info(f"๐Ÿ“Š Average multi-timeframe confidence: {avg_confidence:.1%}") - - if avg_confidence > 0.35: - logger.info("โœ… Multi-timeframe signal would trigger trade") - else: - logger.warning("โŒ Multi-timeframe signal below threshold") - - logger.info("\n=== RECOMMENDATIONS ===") - logger.info("1. โœ… Data flow is working correctly") - logger.info("2. โœ… Price data is fresh and accurate") - logger.info("3. โœ… Confidence thresholds are now more reasonable (35%)") - logger.info("4. โš ๏ธ Complex cross-asset logic has bugs - use simple momentum") - logger.info("5. ๐ŸŽฏ System can generate trading signals - test with real orchestrator") - - except Exception as e: - logger.error(f"โŒ Minimal trading test failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - asyncio.run(test_minimal_trading()) \ No newline at end of file diff --git a/tests/test_model_persistence.py b/tests/test_model_persistence.py deleted file mode 100644 index ba0b580..0000000 --- a/tests/test_model_persistence.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python -""" -Comprehensive test suite for model persistence and training functionality -""" - -import os -import sys -import unittest -import tempfile -import logging -import torch -import numpy as np -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle - -# Configure logging for tests -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -class MockAgent: - """Mock agent class for testing model persistence""" - - def __init__(self, state_size=64, action_size=4, hidden_size=256): - self.state_size = state_size - self.action_size = action_size - self.hidden_size = hidden_size - self.epsilon = 0.1 - - # Create simple mock networks - self.policy_net = torch.nn.Sequential( - torch.nn.Linear(state_size, hidden_size), - torch.nn.ReLU(), - torch.nn.Linear(hidden_size, action_size) - ) - - self.target_net = torch.nn.Sequential( - torch.nn.Linear(state_size, hidden_size), - torch.nn.ReLU(), - torch.nn.Linear(hidden_size, action_size) - ) - - self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001) - -class TestModelPersistence(unittest.TestCase): - """Test suite for model saving and loading functionality""" - - def setUp(self): - """Set up test fixtures""" - self.temp_dir = tempfile.mkdtemp() - self.test_agent = MockAgent() - - def tearDown(self): - """Clean up test fixtures""" - import shutil - shutil.rmtree(self.temp_dir, ignore_errors=True) - - def test_robust_save_basic(self): - """Test basic robust save functionality""" - save_path = os.path.join(self.temp_dir, "test_model.pt") - - success = robust_save(self.test_agent, save_path) - self.assertTrue(success, "Robust save should succeed") - self.assertTrue(os.path.exists(save_path), "Model file should exist") - self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty") - - def test_robust_save_without_optimizer(self): - """Test robust save without optimizer state""" - save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt") - - success = robust_save(self.test_agent, save_path, include_optimizer=False) - self.assertTrue(success, "Robust save without optimizer should succeed") - - # Verify that optimizer state is not included - checkpoint = torch.load(save_path, map_location='cpu') - self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved") - self.assertIn('policy_net', checkpoint, "Policy network should be saved") - - def test_robust_load_basic(self): - """Test basic robust load functionality""" - save_path = os.path.join(self.temp_dir, "test_model.pt") - - # Save first - success = robust_save(self.test_agent, save_path) - self.assertTrue(success, "Save should succeed") - - # Create new agent and load - new_agent = MockAgent() - success = robust_load(new_agent, save_path) - self.assertTrue(success, "Load should succeed") - - # Verify epsilon was loaded - self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match") - - def test_get_model_info(self): - """Test model info extraction""" - save_path = os.path.join(self.temp_dir, "test_model.pt") - - # Test non-existent file - info = get_model_info(save_path) - self.assertFalse(info['exists'], "Non-existent file should return exists=False") - - # Save model and test info - robust_save(self.test_agent, save_path) - info = get_model_info(save_path) - - self.assertTrue(info['exists'], "Existing file should return exists=True") - self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0") - self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint") - self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size) - self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size) - - def test_save_load_cycle_verification(self): - """Test save/load cycle verification""" - test_path = os.path.join(self.temp_dir, "cycle_test.pt") - - success = verify_save_load_cycle(self.test_agent, test_path) - self.assertTrue(success, "Save/load cycle should succeed") - - # File should be cleaned up after verification - self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up") - - def test_multiple_save_methods(self): - """Test that different save methods all work""" - methods = ['regular', 'no_optimizer', 'pickle2'] - - for method in methods: - with self.subTest(method=method): - save_path = os.path.join(self.temp_dir, f"test_{method}.pt") - - if method == 'regular': - success = robust_save(self.test_agent, save_path) - elif method == 'no_optimizer': - success = robust_save(self.test_agent, save_path, include_optimizer=False) - elif method == 'pickle2': - # This would be tested by the robust_save fallback mechanism - success = robust_save(self.test_agent, save_path) - - self.assertTrue(success, f"{method} save should succeed") - self.assertTrue(os.path.exists(save_path), f"{method} save should create file") - -class TestTrainingMetrics(unittest.TestCase): - """Test suite for training metrics and monitoring functionality""" - - def test_signal_distribution_calculation(self): - """Test signal distribution calculation""" - # Mock predictions - predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY - - buy_count = np.sum(predictions == 2) - sell_count = np.sum(predictions == 0) - hold_count = np.sum(predictions == 1) - total = len(predictions) - - distribution = { - "BUY": buy_count / total, - "SELL": sell_count / total, - "HOLD": hold_count / total - } - - self.assertAlmostEqual(distribution["BUY"], 0.3, places=2) - self.assertAlmostEqual(distribution["SELL"], 0.3, places=2) - self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2) - self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2) - - def test_metrics_tracking_structure(self): - """Test metrics history structure for training monitoring""" - metrics_history = { - "epoch": [], - "train_loss": [], - "val_loss": [], - "train_acc": [], - "val_acc": [], - "train_pnl": [], - "val_pnl": [], - "train_win_rate": [], - "val_win_rate": [], - "signal_distribution": [] - } - - # Simulate adding metrics for one epoch - metrics_history["epoch"].append(1) - metrics_history["train_loss"].append(0.5) - metrics_history["val_loss"].append(0.6) - metrics_history["train_acc"].append(0.7) - metrics_history["val_acc"].append(0.65) - metrics_history["train_pnl"].append(0.1) - metrics_history["val_pnl"].append(0.08) - metrics_history["train_win_rate"].append(0.6) - metrics_history["val_win_rate"].append(0.55) - metrics_history["signal_distribution"].append({"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4}) - - # Verify structure - self.assertEqual(len(metrics_history["epoch"]), 1) - self.assertEqual(metrics_history["epoch"][0], 1) - self.assertIsInstance(metrics_history["signal_distribution"][0], dict) - self.assertIn("BUY", metrics_history["signal_distribution"][0]) - -class TestModelArchitecture(unittest.TestCase): - """Test suite for model architecture verification""" - - def test_model_parameter_consistency(self): - """Test that model parameters are consistent after save/load""" - agent = MockAgent(state_size=32, action_size=3, hidden_size=128) - - with tempfile.TemporaryDirectory() as temp_dir: - save_path = os.path.join(temp_dir, "consistency_test.pt") - - # Save model - robust_save(agent, save_path) - - # Load into new model with same architecture - new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128) - robust_load(new_agent, save_path) - - # Verify parameters match - self.assertEqual(new_agent.state_size, agent.state_size) - self.assertEqual(new_agent.action_size, agent.action_size) - self.assertEqual(new_agent.hidden_size, agent.hidden_size) - self.assertEqual(new_agent.epsilon, agent.epsilon) - - def test_model_forward_pass(self): - """Test that model can perform forward pass after load""" - agent = MockAgent() - - with tempfile.TemporaryDirectory() as temp_dir: - save_path = os.path.join(temp_dir, "forward_test.pt") - - # Create test input - test_input = torch.randn(1, agent.state_size) - - # Get original output - original_output = agent.policy_net(test_input) - - # Save and load - robust_save(agent, save_path) - new_agent = MockAgent() - robust_load(new_agent, save_path) - - # Test forward pass works - new_output = new_agent.policy_net(test_input) - - self.assertEqual(new_output.shape, original_output.shape) - # Outputs should be identical since we loaded the same weights - torch.testing.assert_close(new_output, original_output) - -def run_all_tests(): - """Run all test suites""" - test_suites = [ - unittest.TestLoader().loadTestsFromTestCase(TestModelPersistence), - unittest.TestLoader().loadTestsFromTestCase(TestTrainingMetrics), - unittest.TestLoader().loadTestsFromTestCase(TestModelArchitecture) - ] - - combined_suite = unittest.TestSuite(test_suites) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(combined_suite) - - return result.wasSuccessful() - -if __name__ == "__main__": - logger.info("Running comprehensive model persistence and training tests...") - success = run_all_tests() - - if success: - logger.info("All tests passed!") - sys.exit(0) - else: - logger.error("Some tests failed!") - sys.exit(1) \ No newline at end of file diff --git a/tests/test_multi_exchange_cob.py b/tests/test_multi_exchange_cob.py deleted file mode 100644 index 0c035da..0000000 --- a/tests/test_multi_exchange_cob.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Test Multi-Exchange Consolidated Order Book (COB) Provider - -This script demonstrates the functionality of the new multi-exchange COB data provider: -1. Real-time order book aggregation from multiple exchanges -2. Fine-grain price bucket generation -3. CNN/DQN feature generation -4. Dashboard integration -5. Market analysis and signal generation - -Run this to test the COB provider with live data streams. -""" - -import asyncio -import logging -import time -from datetime import datetime -from core.multi_exchange_cob_provider import MultiExchangeCOBProvider -from core.cob_integration import COBIntegration -from core.data_provider import DataProvider - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class COBTester: - """Test harness for Multi-Exchange COB Provider""" - - def __init__(self): - self.symbols = ['BTC/USDT', 'ETH/USDT'] - self.data_provider = None - self.cob_integration = None - self.test_duration = 300 # 5 minutes - - # Statistics tracking - self.stats = { - 'cob_updates_received': 0, - 'bucket_updates_received': 0, - 'cnn_features_generated': 0, - 'dqn_features_generated': 0, - 'signals_generated': 0, - 'start_time': None - } - - async def run_test(self): - """Run comprehensive COB provider test""" - logger.info("Starting Multi-Exchange COB Provider Test") - logger.info(f"Testing symbols: {self.symbols}") - logger.info(f"Test duration: {self.test_duration} seconds") - - try: - # Initialize components - await self._initialize_components() - - # Run test scenarios - await self._run_basic_functionality_test() - await self._run_feature_generation_test() - await self._run_dashboard_integration_test() - await self._run_signal_analysis_test() - - # Monitor for specified duration - await self._monitor_live_data() - - # Generate final report - self._generate_test_report() - - except Exception as e: - logger.error(f"Test failed: {e}") - finally: - await self._cleanup() - - async def _initialize_components(self): - """Initialize COB provider and integration components""" - logger.info("Initializing COB components...") - - # Create data provider (optional - for integration testing) - self.data_provider = DataProvider(symbols=self.symbols) - - # Create COB integration - self.cob_integration = COBIntegration( - data_provider=self.data_provider, - symbols=self.symbols - ) - - # Register test callbacks - self.cob_integration.add_cnn_callback(self._cnn_callback) - self.cob_integration.add_dqn_callback(self._dqn_callback) - self.cob_integration.add_dashboard_callback(self._dashboard_callback) - - # Start COB integration - await self.cob_integration.start() - - # Allow time for connections - await asyncio.sleep(5) - - self.stats['start_time'] = datetime.now() - logger.info("COB components initialized successfully") - - async def _run_basic_functionality_test(self): - """Test basic COB provider functionality""" - logger.info("Testing basic COB functionality...") - - # Wait for order book data - await asyncio.sleep(10) - - for symbol in self.symbols: - # Test consolidated order book retrieval - cob_snapshot = self.cob_integration.get_cob_snapshot(symbol) - if cob_snapshot: - logger.info(f"{symbol} COB Status:") - logger.info(f" Exchanges active: {cob_snapshot.exchanges_active}") - logger.info(f" Volume weighted mid: ${cob_snapshot.volume_weighted_mid:.2f}") - logger.info(f" Spread: {cob_snapshot.spread_bps:.2f} bps") - logger.info(f" Bid liquidity: ${cob_snapshot.total_bid_liquidity:,.0f}") - logger.info(f" Ask liquidity: ${cob_snapshot.total_ask_liquidity:,.0f}") - logger.info(f" Liquidity imbalance: {cob_snapshot.liquidity_imbalance:.3f}") - - # Test price buckets - price_buckets = self.cob_integration.get_price_buckets(symbol) - if price_buckets: - bid_buckets = len(price_buckets.get('bids', {})) - ask_buckets = len(price_buckets.get('asks', {})) - logger.info(f" Price buckets: {bid_buckets} bids, {ask_buckets} asks") - - # Test exchange breakdown - exchange_breakdown = self.cob_integration.get_exchange_breakdown(symbol) - if exchange_breakdown: - logger.info(f" Exchange breakdown:") - for exchange, data in exchange_breakdown.items(): - market_share = data.get('market_share', 0) * 100 - logger.info(f" {exchange}: {market_share:.1f}% market share") - else: - logger.warning(f"No COB data available for {symbol}") - - logger.info("Basic functionality test completed") - - async def _run_feature_generation_test(self): - """Test CNN and DQN feature generation""" - logger.info("Testing feature generation...") - - for symbol in self.symbols: - # Test CNN features - cnn_features = self.cob_integration.get_cob_features(symbol) - if cnn_features is not None: - logger.info(f"{symbol} CNN features: shape={cnn_features.shape}, " - f"min={cnn_features.min():.4f}, max={cnn_features.max():.4f}") - else: - logger.warning(f"No CNN features available for {symbol}") - - # Test market depth analysis - depth_analysis = self.cob_integration.get_market_depth_analysis(symbol) - if depth_analysis: - logger.info(f"{symbol} Market Depth Analysis:") - logger.info(f" Depth levels: {depth_analysis['depth_analysis']['bid_levels']} bids, " - f"{depth_analysis['depth_analysis']['ask_levels']} asks") - - dominant_exchanges = depth_analysis['depth_analysis'].get('dominant_exchanges', {}) - logger.info(f" Dominant exchanges: {dominant_exchanges}") - - logger.info("Feature generation test completed") - - async def _run_dashboard_integration_test(self): - """Test dashboard data generation""" - logger.info("Testing dashboard integration...") - - # Dashboard integration is tested via callbacks - # Statistics are tracked in the callback functions - await asyncio.sleep(5) - - logger.info("Dashboard integration test completed") - - async def _run_signal_analysis_test(self): - """Test signal generation and analysis""" - logger.info("Testing signal analysis...") - - for symbol in self.symbols: - # Get recent signals - recent_signals = self.cob_integration.get_recent_signals(symbol, count=10) - logger.info(f"{symbol} recent signals: {len(recent_signals)} generated") - - for signal in recent_signals[-3:]: # Show last 3 signals - logger.info(f" Signal: {signal.get('type')} - {signal.get('side')} - " - f"Confidence: {signal.get('confidence', 0):.3f}") - - logger.info("Signal analysis test completed") - - async def _monitor_live_data(self): - """Monitor live data for the specified duration""" - logger.info(f"Monitoring live data for {self.test_duration} seconds...") - - start_time = time.time() - last_stats_time = start_time - - while time.time() - start_time < self.test_duration: - # Print periodic statistics - current_time = time.time() - if current_time - last_stats_time >= 30: # Every 30 seconds - self._print_periodic_stats() - last_stats_time = current_time - - await asyncio.sleep(1) - - logger.info("Live data monitoring completed") - - def _print_periodic_stats(self): - """Print periodic statistics during monitoring""" - elapsed = (datetime.now() - self.stats['start_time']).total_seconds() - - logger.info("Periodic Statistics:") - logger.info(f" Elapsed time: {elapsed:.0f} seconds") - logger.info(f" COB updates: {self.stats['cob_updates_received']}") - logger.info(f" Bucket updates: {self.stats['bucket_updates_received']}") - logger.info(f" CNN features: {self.stats['cnn_features_generated']}") - logger.info(f" DQN features: {self.stats['dqn_features_generated']}") - logger.info(f" Signals: {self.stats['signals_generated']}") - - # Calculate rates - if elapsed > 0: - cob_rate = self.stats['cob_updates_received'] / elapsed - logger.info(f" COB update rate: {cob_rate:.2f}/sec") - - def _generate_test_report(self): - """Generate final test report""" - elapsed = (datetime.now() - self.stats['start_time']).total_seconds() - - logger.info("=" * 60) - logger.info("MULTI-EXCHANGE COB PROVIDER TEST REPORT") - logger.info("=" * 60) - logger.info(f"Test Duration: {elapsed:.0f} seconds") - logger.info(f"Symbols Tested: {', '.join(self.symbols)}") - logger.info("") - - # Data Reception Statistics - logger.info("Data Reception:") - logger.info(f" COB Updates Received: {self.stats['cob_updates_received']}") - logger.info(f" Bucket Updates Received: {self.stats['bucket_updates_received']}") - logger.info(f" Average COB Rate: {self.stats['cob_updates_received'] / elapsed:.2f}/sec") - logger.info("") - - # Feature Generation Statistics - logger.info("Feature Generation:") - logger.info(f" CNN Features Generated: {self.stats['cnn_features_generated']}") - logger.info(f" DQN Features Generated: {self.stats['dqn_features_generated']}") - logger.info("") - - # Signal Generation Statistics - logger.info("Signal Analysis:") - logger.info(f" Signals Generated: {self.stats['signals_generated']}") - logger.info("") - - # Component Statistics - cob_stats = self.cob_integration.get_statistics() - logger.info("Component Statistics:") - logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}") - logger.info(f" Streaming Status: {cob_stats.get('is_streaming', False)}") - logger.info(f" Bucket Size: {cob_stats.get('bucket_size_bps', 0)} bps") - logger.info(f" Average Processing Time: {cob_stats.get('avg_processing_time_ms', 0):.2f} ms") - logger.info("") - - # Per-Symbol Analysis - logger.info("Per-Symbol Analysis:") - for symbol in self.symbols: - cob_snapshot = self.cob_integration.get_cob_snapshot(symbol) - if cob_snapshot: - logger.info(f" {symbol}:") - logger.info(f" Active Exchanges: {len(cob_snapshot.exchanges_active)}") - logger.info(f" Spread: {cob_snapshot.spread_bps:.2f} bps") - logger.info(f" Total Liquidity: ${(cob_snapshot.total_bid_liquidity + cob_snapshot.total_ask_liquidity):,.0f}") - - recent_signals = self.cob_integration.get_recent_signals(symbol) - logger.info(f" Signals Generated: {len(recent_signals)}") - - logger.info("=" * 60) - logger.info("Test completed successfully!") - - async def _cleanup(self): - """Cleanup resources""" - logger.info("Cleaning up resources...") - - if self.cob_integration: - await self.cob_integration.stop() - - if self.data_provider and hasattr(self.data_provider, 'stop_real_time_streaming'): - await self.data_provider.stop_real_time_streaming() - - logger.info("Cleanup completed") - - # Callback functions for testing - - def _cnn_callback(self, symbol: str, data: dict): - """CNN feature callback for testing""" - self.stats['cnn_features_generated'] += 1 - if self.stats['cnn_features_generated'] % 100 == 0: - logger.debug(f"CNN features generated: {self.stats['cnn_features_generated']}") - - def _dqn_callback(self, symbol: str, data: dict): - """DQN feature callback for testing""" - self.stats['dqn_features_generated'] += 1 - if self.stats['dqn_features_generated'] % 100 == 0: - logger.debug(f"DQN features generated: {self.stats['dqn_features_generated']}") - - def _dashboard_callback(self, symbol: str, data: dict): - """Dashboard data callback for testing""" - self.stats['cob_updates_received'] += 1 - - # Check for signals in dashboard data - signals = data.get('recent_signals', []) - self.stats['signals_generated'] += len(signals) - -async def main(): - """Main test function""" - logger.info("Multi-Exchange COB Provider Test Starting...") - - try: - tester = COBTester() - await tester.run_test() - except KeyboardInterrupt: - logger.info("Test interrupted by user") - except Exception as e: - logger.error(f"Test failed with error: {e}") - raise - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_negative_case_training.py b/tests/test_negative_case_training.py deleted file mode 100644 index 0fc41d4..0000000 --- a/tests/test_negative_case_training.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for negative case training functionality - -This script tests: -1. Negative case trainer initialization -2. Adding losing trades for intensive training -3. Storage in testcases/negative folder -4. Simultaneous inference and training -5. 500x leverage training case generation -""" - -import logging -import time -from datetime import datetime -from core.negative_case_trainer import NegativeCaseTrainer, NegativeCase -from core.trading_action import TradingAction - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_negative_case_trainer(): - """Test negative case trainer functionality""" - print("๐Ÿ”ด Testing Negative Case Trainer for Intensive Training on Losses") - print("=" * 70) - - # Test 1: Initialize trainer - print("\n1. Initializing Negative Case Trainer...") - trainer = NegativeCaseTrainer() - print(f"โœ… Trainer initialized with storage at: {trainer.storage_dir}") - print(f"โœ… Background training thread started: {trainer.training_thread.is_alive()}") - - # Test 2: Create a losing trade scenario - print("\n2. Creating losing trade scenarios...") - - # Scenario 1: Small loss (1% with 500x leverage = 500% loss) - trade_info_1 = { - 'timestamp': datetime.now(), - 'symbol': 'ETH/USDT', - 'action': 'BUY', - 'price': 3000.0, - 'size': 0.1, - 'value': 300.0, - 'confidence': 0.8, - 'pnl': -3.0 # $3 loss on $300 position = 1% loss - } - - market_data_1 = { - 'exit_price': 2970.0, # 1% drop - 'state_before': { - 'volatility': 2.5, - 'momentum': 0.5, - 'volume_ratio': 1.2 - }, - 'state_after': { - 'volatility': 3.0, - 'momentum': -1.0, - 'volume_ratio': 0.8 - }, - 'tick_data': [], - 'technical_indicators': { - 'rsi': 65, - 'macd': 0.5 - } - } - - case_id_1 = trainer.add_losing_trade(trade_info_1, market_data_1) - print(f"โœ… Added small loss case: {case_id_1}") - - # Scenario 2: Large loss (5% with 500x leverage = 2500% loss) - trade_info_2 = { - 'timestamp': datetime.now(), - 'symbol': 'ETH/USDT', - 'action': 'SELL', - 'price': 3000.0, - 'size': 0.2, - 'value': 600.0, - 'confidence': 0.9, - 'pnl': -30.0 # $30 loss on $600 position = 5% loss - } - - market_data_2 = { - 'exit_price': 3150.0, # 5% rise (bad for short) - 'state_before': { - 'volatility': 1.8, - 'momentum': -0.3, - 'volume_ratio': 0.9 - }, - 'state_after': { - 'volatility': 4.2, - 'momentum': 2.5, - 'volume_ratio': 1.8 - }, - 'tick_data': [], - 'technical_indicators': { - 'rsi': 35, - 'macd': -0.8 - } - } - - case_id_2 = trainer.add_losing_trade(trade_info_2, market_data_2) - print(f"โœ… Added large loss case: {case_id_2}") - - # Test 3: Check training stats - print("\n3. Checking training statistics...") - stats = trainer.get_training_stats() - print(f"โœ… Total negative cases: {stats['total_negative_cases']}") - print(f"โœ… Cases in training queue: {stats['cases_in_queue']}") - print(f"โœ… High priority cases: {stats['high_priority_cases']}") - print(f"โœ… Training active: {stats['training_active']}") - print(f"โœ… Storage directory: {stats['storage_directory']}") - - # Test 4: Check recent lessons - print("\n4. Recent lessons learned...") - lessons = trainer.get_recent_lessons(3) - for i, lesson in enumerate(lessons, 1): - print(f"โœ… Lesson {i}: {lesson}") - - # Test 5: Test simultaneous inference capability - print("\n5. Testing simultaneous inference and training...") - for i in range(5): - can_inference = trainer.can_inference_proceed() - print(f"โœ… Inference check {i+1}: {'ALLOWED' if can_inference else 'BLOCKED'}") - time.sleep(0.5) - - # Test 6: Wait for some training to complete - print("\n6. Waiting for intensive training to process cases...") - time.sleep(3) # Wait for background training - - # Check updated stats - updated_stats = trainer.get_training_stats() - print(f"โœ… Cases processed: {updated_stats['total_cases_processed']}") - print(f"โœ… Total training time: {updated_stats['total_training_time']:.2f}s") - print(f"โœ… Avg accuracy improvement: {updated_stats['avg_accuracy_improvement']:.1%}") - - # Test 7: 500x leverage training case analysis - print("\n7. 500x Leverage Training Case Analysis...") - print("๐Ÿ’ก With 0% fees, any move >0.1% is profitable at 500x leverage:") - - test_moves = [0.05, 0.1, 0.15, 0.2, 0.5, 1.0] # Price change percentages - for move_pct in test_moves: - leverage_profit = move_pct * 500 - profitable = move_pct >= 0.1 - status = "โœ… PROFITABLE" if profitable else "โŒ TOO SMALL" - print(f" {move_pct:+.2f}% move = {leverage_profit:+.1f}% @ 500x leverage - {status}") - - print("\n๐Ÿ”ด PRIORITY: Losing trades trigger intensive RL retraining") - print("๐Ÿš€ System optimized for fast trading with 500x leverage and 0% fees") - print("โšก Training cases generated for all moves >0.1% to maximize profit") - - return trainer - -def test_integration_with_enhanced_dashboard(): - """Test integration with enhanced dashboard""" - print("\n" + "=" * 70) - print("๐Ÿ”— Testing Integration with Enhanced Dashboard") - print("=" * 70) - - try: - from web.old_archived.enhanced_scalping_dashboard import EnhancedScalpingDashboard - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - - # Create components - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator(data_provider) - dashboard = EnhancedScalpingDashboard(data_provider, orchestrator) - - print("โœ… Enhanced dashboard created successfully") - print(f"โœ… Orchestrator has negative case trainer: {hasattr(orchestrator, 'negative_case_trainer')}") - print(f"โœ… Trading session has orchestrator reference: {hasattr(dashboard.trading_session, 'orchestrator')}") - - # Test negative case trainer access - if hasattr(orchestrator, 'negative_case_trainer'): - trainer_stats = orchestrator.negative_case_trainer.get_training_stats() - print(f"โœ… Negative case trainer accessible with {trainer_stats['total_negative_cases']} cases") - - return True - - except Exception as e: - print(f"โŒ Integration test failed: {e}") - return False - -if __name__ == "__main__": - print("๐Ÿ”ด NEGATIVE CASE TRAINING TEST SUITE") - print("Focus: Learning from losses to prevent future mistakes") - print("Features: 500x leverage optimization, 0% fee advantage, intensive retraining") - - try: - # Test negative case trainer - trainer = test_negative_case_trainer() - - # Test integration - integration_success = test_integration_with_enhanced_dashboard() - - print("\n" + "=" * 70) - print("๐Ÿ“Š TEST SUMMARY") - print("=" * 70) - print("โœ… Negative case trainer: WORKING") - print("โœ… Intensive training on losses: ACTIVE") - print("โœ… Storage in testcases/negative: WORKING") - print("โœ… Simultaneous inference/training: SUPPORTED") - print("โœ… 500x leverage optimization: IMPLEMENTED") - print(f"โœ… Enhanced dashboard integration: {'WORKING' if integration_success else 'NEEDS ATTENTION'}") - - print("\n๐ŸŽฏ SYSTEM READY FOR INTENSIVE LOSS-BASED LEARNING") - print("๐Ÿ’ช Every losing trade makes the system stronger!") - - except Exception as e: - print(f"\nโŒ Test suite failed: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/tests/test_nn_driven_trading.py b/tests/test_nn_driven_trading.py deleted file mode 100644 index 057d80d..0000000 --- a/tests/test_nn_driven_trading.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -""" -Test NN-Driven Trading System -Demonstrates how the system now makes decisions using Neural Networks instead of algorithms -""" - -import logging -import asyncio -from datetime import datetime -import numpy as np - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -async def test_nn_driven_system(): - """Test the NN-driven trading system""" - logger.info("=== TESTING NN-DRIVEN TRADING SYSTEM ===") - - try: - # Import core components - from core.config import get_config - from core.data_provider import DataProvider - from core.enhanced_orchestrator import EnhancedTradingOrchestrator - from core.nn_decision_fusion import ModelPrediction, MarketContext - - # Initialize components - config = get_config() - data_provider = DataProvider() - - # Initialize NN-driven orchestrator - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - enhanced_rl_training=True - ) - - logger.info("โœ… NN-driven orchestrator initialized") - - # Test 1: Add mock CNN prediction - cnn_prediction = ModelPrediction( - model_name="williams_cnn", - prediction_type="direction", - value=0.6, # Bullish signal - confidence=0.8, - timestamp=datetime.now(), - metadata={'timeframe': '1h', 'feature_importance': [0.2, 0.3, 0.5]} - ) - - orchestrator.neural_fusion.add_prediction(cnn_prediction) - logger.info("๐Ÿ”ฎ Added CNN prediction: BULLISH (0.6) with 80% confidence") - - # Test 2: Add mock RL prediction - rl_prediction = ModelPrediction( - model_name="dqn_agent", - prediction_type="action", - value=0.4, # Moderate buy signal - confidence=0.7, - timestamp=datetime.now(), - metadata={'action_probs': [0.4, 0.2, 0.4]} # [BUY, SELL, HOLD] - ) - - orchestrator.neural_fusion.add_prediction(rl_prediction) - logger.info("๐Ÿ”ฎ Added RL prediction: MODERATE_BUY (0.4) with 70% confidence") - - # Test 3: Add mock COB RL prediction - cob_prediction = ModelPrediction( - model_name="cob_rl", - prediction_type="direction", - value=0.3, # Slightly bullish - confidence=0.85, - timestamp=datetime.now(), - metadata={'cob_imbalance': 0.1, 'liquidity_depth': 150000} - ) - - orchestrator.neural_fusion.add_prediction(cob_prediction) - logger.info("๐Ÿ”ฎ Added COB RL prediction: SLIGHT_BULLISH (0.3) with 85% confidence") - - # Test 4: Create market context - market_context = MarketContext( - symbol='ETH/USDT', - current_price=2441.50, - price_change_1m=0.002, # 0.2% up in 1m - price_change_5m=0.008, # 0.8% up in 5m - volume_ratio=1.2, # 20% above average volume - volatility=0.015, # 1.5% volatility - timestamp=datetime.now() - ) - - logger.info(f"๐Ÿ“Š Market Context: ETH/USDT at ${market_context.current_price}") - logger.info(f" ๐Ÿ“ˆ Price changes: 1m: {market_context.price_change_1m:.3f}, 5m: {market_context.price_change_5m:.3f}") - logger.info(f" ๐Ÿ“Š Volume ratio: {market_context.volume_ratio:.2f}, Volatility: {market_context.volatility:.3f}") - - # Test 5: Make NN decision - fusion_decision = orchestrator.neural_fusion.make_decision( - symbol='ETH/USDT', - market_context=market_context, - min_confidence=0.25 - ) - - if fusion_decision: - logger.info("๐Ÿง  === NN DECISION RESULT ===") - logger.info(f" Action: {fusion_decision.action}") - logger.info(f" Confidence: {fusion_decision.confidence:.3f}") - logger.info(f" Expected Return: {fusion_decision.expected_return:.3f}") - logger.info(f" Risk Score: {fusion_decision.risk_score:.3f}") - logger.info(f" Position Size: {fusion_decision.position_size:.4f} ETH") - logger.info(f" Reasoning: {fusion_decision.reasoning}") - logger.info(" Model Contributions:") - for model, contribution in fusion_decision.model_contributions.items(): - logger.info(f" - {model}: {contribution:.1%}") - else: - logger.warning("โŒ No NN decision generated") - - # Test 6: Test coordinated decisions - logger.info("\n๐ŸŽฏ Testing coordinated NN decisions...") - decisions = await orchestrator.make_coordinated_decisions() - - if decisions: - logger.info(f"โœ… Generated {len(decisions)} NN-driven trading decisions:") - for i, decision in enumerate(decisions): - logger.info(f" Decision {i+1}: {decision.symbol} {decision.action} " - f"({decision.confidence:.3f} confidence, " - f"{decision.quantity:.4f} size)") - if hasattr(decision, 'metadata') and decision.metadata: - if decision.metadata.get('nn_driven'): - logger.info(f" ๐Ÿง  NN-DRIVEN: {decision.metadata.get('reasoning', 'No reasoning')}") - else: - logger.info("โ„น๏ธ No trading decisions generated (insufficient confidence)") - - # Test 7: Check NN system status - nn_status = orchestrator.neural_fusion.get_status() - logger.info("\n๐Ÿ“Š NN System Status:") - logger.info(f" Device: {nn_status['device']}") - logger.info(f" Training Mode: {nn_status['training_mode']}") - logger.info(f" Registered Models: {nn_status['registered_models']}") - logger.info(f" Recent Predictions: {nn_status['recent_predictions']}") - logger.info(f" Model Parameters: {nn_status['model_parameters']:,}") - - # Test 8: Demonstrate different confidence scenarios - logger.info("\n๐Ÿ”ฌ Testing different confidence scenarios...") - - # Low confidence scenario - low_conf_prediction = ModelPrediction( - model_name="williams_cnn", - prediction_type="direction", - value=0.1, # Weak signal - confidence=0.2, # Low confidence - timestamp=datetime.now() - ) - - orchestrator.neural_fusion.add_prediction(low_conf_prediction) - low_conf_decision = orchestrator.neural_fusion.make_decision( - symbol='ETH/USDT', - market_context=market_context, - min_confidence=0.25 - ) - - if low_conf_decision: - logger.info(f" Low confidence result: {low_conf_decision.action} (should be HOLD)") - else: - logger.info(" โœ… Low confidence correctly resulted in no decision") - - # High confidence scenario - high_conf_prediction = ModelPrediction( - model_name="williams_cnn", - prediction_type="direction", - value=0.8, # Strong signal - confidence=0.95, # Very high confidence - timestamp=datetime.now() - ) - - orchestrator.neural_fusion.add_prediction(high_conf_prediction) - high_conf_decision = orchestrator.neural_fusion.make_decision( - symbol='ETH/USDT', - market_context=market_context, - min_confidence=0.25 - ) - - if high_conf_decision: - logger.info(f" High confidence result: {high_conf_decision.action} " - f"(conf: {high_conf_decision.confidence:.3f}, " - f"size: {high_conf_decision.position_size:.4f})") - - logger.info("\nโœ… NN-DRIVEN TRADING SYSTEM TEST COMPLETE") - logger.info("๐ŸŽฏ Key Benefits Demonstrated:") - logger.info(" 1. Multiple NN models provide predictions") - logger.info(" 2. Central NN fusion makes final decisions") - logger.info(" 3. Market context influences decisions") - logger.info(" 4. Confidence thresholds prevent bad trades") - logger.info(" 5. Position sizing based on NN outputs") - logger.info(" 6. Clear reasoning for every decision") - logger.info(" 7. Model contribution tracking") - - except Exception as e: - logger.error(f"Error in NN-driven system test: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - asyncio.run(test_nn_driven_system()) \ No newline at end of file diff --git a/tests/test_pivot_normalization_system.py b/tests/test_pivot_normalization_system.py deleted file mode 100644 index e47f613..0000000 --- a/tests/test_pivot_normalization_system.py +++ /dev/null @@ -1,305 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Pivot-Based Normalization System - -This script tests the comprehensive pivot-based normalization system: -1. Monthly 1s data collection with pagination -2. Williams Market Structure pivot analysis -3. Pivot bounds extraction and caching -4. Pivot-based feature normalization -5. Integration with model training pipeline -""" - -import sys -import os -import logging -from datetime import datetime, timedelta - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.data_provider import DataProvider -from core.config import get_config - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_pivot_normalization_system(): - """Test the complete pivot-based normalization system""" - - print("="*80) - print("TESTING PIVOT-BASED NORMALIZATION SYSTEM") - print("="*80) - - # Initialize data provider - symbols = ['ETH/USDT'] # Test with ETH only - timeframes = ['1s'] - - logger.info("Initializing DataProvider with pivot-based normalization...") - data_provider = DataProvider(symbols=symbols, timeframes=timeframes) - - # Test 1: Monthly Data Collection - print("\n" + "="*60) - print("TEST 1: MONTHLY 1S DATA COLLECTION") - print("="*60) - - symbol = 'ETH/USDT' - - try: - # This will trigger monthly data collection and pivot analysis - logger.info(f"Testing monthly data collection for {symbol}...") - monthly_data = data_provider._collect_monthly_1m_data(symbol) - - if monthly_data is not None: - print(f"โœ… Monthly data collection SUCCESS") - print(f" ๐Ÿ“Š Collected {len(monthly_data):,} 1m candles") - print(f" ๐Ÿ“… Period: {monthly_data['timestamp'].min()} to {monthly_data['timestamp'].max()}") - print(f" ๐Ÿ’ฐ Price range: ${monthly_data['low'].min():.2f} - ${monthly_data['high'].max():.2f}") - print(f" ๐Ÿ“ˆ Volume range: {monthly_data['volume'].min():.2f} - {monthly_data['volume'].max():.2f}") - else: - print("โŒ Monthly data collection FAILED") - return False - - except Exception as e: - print(f"โŒ Monthly data collection ERROR: {e}") - return False - - # Test 2: Pivot Bounds Extraction - print("\n" + "="*60) - print("TEST 2: PIVOT BOUNDS EXTRACTION") - print("="*60) - - try: - logger.info("Testing pivot bounds extraction...") - bounds = data_provider._extract_pivot_bounds_from_monthly_data(symbol, monthly_data) - - if bounds is not None: - print(f"โœ… Pivot bounds extraction SUCCESS") - print(f" ๐Ÿ’ฐ Price bounds: ${bounds.price_min:.2f} - ${bounds.price_max:.2f}") - print(f" ๐Ÿ“Š Volume bounds: {bounds.volume_min:.2f} - {bounds.volume_max:.2f}") - print(f" ๐Ÿ”ธ Support levels: {len(bounds.pivot_support_levels)}") - print(f" ๐Ÿ”น Resistance levels: {len(bounds.pivot_resistance_levels)}") - print(f" ๐Ÿ“ˆ Candles analyzed: {bounds.total_candles_analyzed:,}") - print(f" โฐ Created: {bounds.created_timestamp}") - - # Store bounds for next tests - data_provider.pivot_bounds[symbol] = bounds - else: - print("โŒ Pivot bounds extraction FAILED") - return False - - except Exception as e: - print(f"โŒ Pivot bounds extraction ERROR: {e}") - return False - - # Test 3: Pivot Context Features - print("\n" + "="*60) - print("TEST 3: PIVOT CONTEXT FEATURES") - print("="*60) - - try: - logger.info("Testing pivot context features...") - - # Get recent data for testing - recent_data = data_provider.get_historical_data(symbol, '1m', limit=100) - - if recent_data is not None and not recent_data.empty: - # Add pivot context features - with_pivot_features = data_provider._add_pivot_context_features(recent_data, symbol) - - # Check if pivot features were added - pivot_features = [col for col in with_pivot_features.columns if 'pivot' in col] - - if pivot_features: - print(f"โœ… Pivot context features SUCCESS") - print(f" ๐ŸŽฏ Added features: {pivot_features}") - - # Show sample values - latest_row = with_pivot_features.iloc[-1] - print(f" ๐Ÿ“Š Latest values:") - for feature in pivot_features: - print(f" {feature}: {latest_row[feature]:.4f}") - else: - print("โŒ No pivot context features added") - return False - else: - print("โŒ Could not get recent data for testing") - return False - - except Exception as e: - print(f"โŒ Pivot context features ERROR: {e}") - return False - - # Test 4: Pivot-Based Normalization - print("\n" + "="*60) - print("TEST 4: PIVOT-BASED NORMALIZATION") - print("="*60) - - try: - logger.info("Testing pivot-based normalization...") - - # Get data with technical indicators - data_with_indicators = data_provider.get_historical_data(symbol, '1m', limit=50) - - if data_with_indicators is not None and not data_with_indicators.empty: - # Test traditional vs pivot normalization - traditional_norm = data_provider._normalize_features(data_with_indicators.tail(10)) - pivot_norm = data_provider._normalize_features(data_with_indicators.tail(10), symbol=symbol) - - print(f"โœ… Pivot-based normalization SUCCESS") - print(f" ๐Ÿ“Š Traditional normalization shape: {traditional_norm.shape}") - print(f" ๐ŸŽฏ Pivot normalization shape: {pivot_norm.shape}") - - # Compare price normalization - if 'close' in pivot_norm.columns: - trad_close_range = traditional_norm['close'].max() - traditional_norm['close'].min() - pivot_close_range = pivot_norm['close'].max() - pivot_norm['close'].min() - - print(f" ๐Ÿ’ฐ Traditional close range: {trad_close_range:.6f}") - print(f" ๐ŸŽฏ Pivot close range: {pivot_close_range:.6f}") - - # Pivot normalization should be better bounded - if 0 <= pivot_norm['close'].min() and pivot_norm['close'].max() <= 1: - print(f" โœ… Pivot normalization properly bounded [0,1]") - else: - print(f" โš ๏ธ Pivot normalization outside [0,1] bounds") - else: - print("โŒ Could not get data for normalization testing") - return False - - except Exception as e: - print(f"โŒ Pivot-based normalization ERROR: {e}") - return False - - # Test 5: Feature Matrix with Pivot Normalization - print("\n" + "="*60) - print("TEST 5: FEATURE MATRIX WITH PIVOT NORMALIZATION") - print("="*60) - - try: - logger.info("Testing feature matrix with pivot normalization...") - - # Create feature matrix using pivot normalization - feature_matrix = data_provider.get_feature_matrix(symbol, timeframes=['1m'], window_size=20) - - if feature_matrix is not None: - print(f"โœ… Feature matrix with pivot normalization SUCCESS") - print(f" ๐Ÿ“Š Matrix shape: {feature_matrix.shape}") - print(f" ๐ŸŽฏ Data range: [{feature_matrix.min():.4f}, {feature_matrix.max():.4f}]") - print(f" ๐Ÿ“ˆ Mean: {feature_matrix.mean():.4f}") - print(f" ๐Ÿ“Š Std: {feature_matrix.std():.4f}") - - # Check for proper normalization - if feature_matrix.min() >= -5 and feature_matrix.max() <= 5: # Reasonable bounds - print(f" โœ… Feature matrix reasonably bounded") - else: - print(f" โš ๏ธ Feature matrix may have extreme values") - else: - print("โŒ Feature matrix creation FAILED") - return False - - except Exception as e: - print(f"โŒ Feature matrix ERROR: {e}") - return False - - # Test 6: Caching System - print("\n" + "="*60) - print("TEST 6: CACHING SYSTEM") - print("="*60) - - try: - logger.info("Testing caching system...") - - # Test pivot bounds caching - original_bounds = data_provider.pivot_bounds[symbol] - data_provider._save_pivot_bounds_to_cache(symbol, original_bounds) - - # Clear from memory and reload - del data_provider.pivot_bounds[symbol] - loaded_bounds = data_provider._load_pivot_bounds_from_cache(symbol) - - if loaded_bounds is not None: - print(f"โœ… Pivot bounds caching SUCCESS") - print(f" ๐Ÿ’พ Original price range: ${original_bounds.price_min:.2f} - ${original_bounds.price_max:.2f}") - print(f" ๐Ÿ’พ Loaded price range: ${loaded_bounds.price_min:.2f} - ${loaded_bounds.price_max:.2f}") - - # Restore bounds - data_provider.pivot_bounds[symbol] = loaded_bounds - else: - print("โŒ Pivot bounds caching FAILED") - return False - - except Exception as e: - print(f"โŒ Caching system ERROR: {e}") - return False - - # Test 7: Public API Methods - print("\n" + "="*60) - print("TEST 7: PUBLIC API METHODS") - print("="*60) - - try: - logger.info("Testing public API methods...") - - # Test get_pivot_bounds - api_bounds = data_provider.get_pivot_bounds(symbol) - if api_bounds is not None: - print(f"โœ… get_pivot_bounds() SUCCESS") - print(f" ๐Ÿ“Š Returned bounds for {api_bounds.symbol}") - - # Test get_pivot_normalized_features - test_data = data_provider.get_historical_data(symbol, '1m', limit=10) - if test_data is not None: - normalized_data = data_provider.get_pivot_normalized_features(symbol, test_data) - if normalized_data is not None: - print(f"โœ… get_pivot_normalized_features() SUCCESS") - print(f" ๐Ÿ“Š Normalized data shape: {normalized_data.shape}") - else: - print("โŒ get_pivot_normalized_features() FAILED") - return False - - except Exception as e: - print(f"โŒ Public API methods ERROR: {e}") - return False - - # Final Summary - print("\n" + "="*80) - print("๐ŸŽ‰ PIVOT-BASED NORMALIZATION SYSTEM TEST COMPLETE") - print("="*80) - print("โœ… All tests PASSED successfully!") - print("\n๐Ÿ“‹ System Features Verified:") - print(" โœ… Monthly 1s data collection with pagination") - print(" โœ… Williams Market Structure pivot analysis") - print(" โœ… Pivot bounds extraction and validation") - print(" โœ… Pivot context features generation") - print(" โœ… Pivot-based feature normalization") - print(" โœ… Feature matrix creation with pivot bounds") - print(" โœ… Comprehensive caching system") - print(" โœ… Public API methods") - - print(f"\n๐ŸŽฏ Ready for model training with pivot-normalized features!") - return True - -if __name__ == "__main__": - try: - success = test_pivot_normalization_system() - - if success: - print("\n๐Ÿš€ Pivot-based normalization system ready for production!") - sys.exit(0) - else: - print("\nโŒ Pivot-based normalization system has issues!") - sys.exit(1) - - except KeyboardInterrupt: - print("\nโน๏ธ Test interrupted by user") - sys.exit(1) - except Exception as e: - print(f"\n๐Ÿ’ฅ Unexpected error: {e}") - import traceback - traceback.print_exc() - sys.exit(1) \ No newline at end of file diff --git a/tests/test_pnl_tracking.py b/tests/test_pnl_tracking.py deleted file mode 100644 index 52176a3..0000000 --- a/tests/test_pnl_tracking.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -""" -Test PnL Tracking System - -This script demonstrates the ultra-fast scalping PnL tracking system -""" - -import time -import logging -from run_scalping_dashboard import UltraFastScalpingRunner - -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - -def test_pnl_tracking(): - """Test the PnL tracking system""" - print("๐Ÿ”ฅ TESTING ULTRA-FAST SCALPING PnL TRACKING ๐Ÿ”ฅ") - print("="*60) - - # Create runner - runner = UltraFastScalpingRunner() - - print(f"๐Ÿ’ฐ Starting Balance: ${runner.balance:.2f}") - print(f"๐Ÿ“Š Leverage: {runner.leverage}x") - print(f"๐Ÿ’ณ Trading Fee: {runner.trading_fee*100:.3f}% per trade") - print("โšก Starting simulation for 30 seconds...") - print("="*60) - - # Start simulation - runner.start_ultra_fast_simulation() - - try: - # Run for 30 seconds - time.sleep(30) - except KeyboardInterrupt: - print("\n๐Ÿ›‘ Stopping simulation...") - - # Stop simulation - runner.running = False - - # Wait for threads to finish - if runner.simulation_thread: - runner.simulation_thread.join(timeout=2) - if runner.exit_monitor_thread: - runner.exit_monitor_thread.join(timeout=2) - - # Print final results - print("\n" + "="*60) - print("๐Ÿ’ผ FINAL PnL TRACKING RESULTS:") - print("="*60) - print(f"๐Ÿ“Š Total Trades: {len(runner.closed_trades)}") - print(f"๐ŸŽฏ Total PnL: ${runner.total_pnl:+.2f}") - print(f"๐Ÿ’ณ Total Fees: ${runner.total_fees:.2f}") - print(f"๐ŸŸข Wins: {runner.win_count} | ๐Ÿ”ด Losses: {runner.loss_count}") - if runner.win_count + runner.loss_count > 0: - win_rate = runner.win_count / (runner.win_count + runner.loss_count) - print(f"๐Ÿ“ˆ Win Rate: {win_rate*100:.1f}%") - print(f"๐Ÿ’ฐ Starting Balance: ${runner.balance:.2f}") - print(f"๐Ÿ’ฐ Final Balance: ${runner.balance + runner.total_pnl:.2f}") - if runner.balance > 0: - return_pct = ((runner.balance + runner.total_pnl) / runner.balance - 1) * 100 - print(f"๐Ÿ“Š Return: {return_pct:+.2f}%") - print(f"๐Ÿ“‹ Open Positions: {len(runner.open_positions)}") - print("="*60) - - # Show sample of closed trades - if runner.closed_trades: - print("\n๐Ÿ“ˆ SAMPLE CLOSED TRADES:") - print("-" * 40) - for i, trade in enumerate(runner.closed_trades[-5:]): # Last 5 trades - duration = (trade.exit_time - trade.entry_time).total_seconds() - pnl_color = "๐ŸŸข" if trade.pnl > 0 else "๐Ÿ”ด" - print(f"{pnl_color} Trade #{trade.trade_id}: {trade.action} {trade.symbol}") - print(f" Entry: ${trade.entry_price:.2f} โ†’ Exit: ${trade.exit_price:.2f}") - print(f" Duration: {duration:.1f}s | PnL: ${trade.pnl:+.2f}") - - print("\nโœ… PnL Tracking Test Complete!") - -if __name__ == "__main__": - test_pnl_tracking() \ No newline at end of file diff --git a/tests/test_pnl_tracking_enhanced.py b/tests/test_pnl_tracking_enhanced.py deleted file mode 100644 index cc3f886..0000000 --- a/tests/test_pnl_tracking_enhanced.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for enhanced PnL tracking with position flipping and color coding -""" - -import sys -import logging -from datetime import datetime, timezone - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_enhanced_pnl_tracking(): - """Test the enhanced PnL tracking with position flipping""" - try: - print("="*60) - print("TESTING ENHANCED PnL TRACKING & POSITION COLOR CODING") - print("="*60) - - # Import dashboard - from web.clean_dashboard import CleanTradingDashboard as TradingDashboard - - # Create dashboard instance - dashboard = TradingDashboard() - - print(f"โœ“ Dashboard created") - print(f"โœ“ Initial position: {dashboard.current_position}") - print(f"โœ“ Initial realized PnL: ${dashboard.total_realized_pnl:.2f}") - print(f"โœ“ Initial session trades: {len(dashboard.session_trades)}") - - # Test sequence of trades with position flipping - test_trades = [ - {'action': 'BUY', 'price': 3000.0, 'size': 0.1, 'confidence': 0.75}, # Open LONG - {'action': 'SELL', 'price': 3050.0, 'size': 0.1, 'confidence': 0.80}, # Close LONG (+$5 profit) - {'action': 'SELL', 'price': 3040.0, 'size': 0.1, 'confidence': 0.70}, # Open SHORT - {'action': 'BUY', 'price': 3020.0, 'size': 0.1, 'confidence': 0.85}, # Close SHORT (+$2 profit) & flip to LONG - {'action': 'SELL', 'price': 3010.0, 'size': 0.1, 'confidence': 0.65}, # Close LONG (-$1 loss) - ] - - print("\n" + "="*60) - print("EXECUTING TEST TRADE SEQUENCE:") - print("="*60) - - for i, trade in enumerate(test_trades, 1): - print(f"\n--- Trade {i}: {trade['action']} @ ${trade['price']:.2f} ---") - - # Add required fields - trade['symbol'] = 'ETH/USDT' - trade['timestamp'] = datetime.now(timezone.utc) - trade['reason'] = f'Test trade {i}' - - # Process the trade - dashboard._process_trading_decision(trade) - - # Show results - print(f"Current position: {dashboard.current_position}") - print(f"Realized PnL: ${dashboard.total_realized_pnl:.2f}") - print(f"Total trades: {len(dashboard.session_trades)}") - print(f"Recent decisions: {len(dashboard.recent_decisions)}") - - # Test unrealized PnL calculation - if dashboard.current_position: - current_price = trade['price'] + 5.0 # Simulate price movement - unrealized_pnl = dashboard._calculate_unrealized_pnl(current_price) - print(f"Unrealized PnL @ ${current_price:.2f}: ${unrealized_pnl:.2f}") - - print("\n" + "="*60) - print("FINAL RESULTS:") - print("="*60) - print(f"โœ“ Total realized PnL: ${dashboard.total_realized_pnl:.2f}") - print(f"โœ“ Total fees paid: ${dashboard.total_fees:.2f}") - print(f"โœ“ Total trades executed: {len(dashboard.session_trades)}") - print(f"โœ“ Final position: {dashboard.current_position}") - - # Test session performance calculation - print("\n" + "="*60) - print("SESSION PERFORMANCE TEST:") - print("="*60) - - try: - session_perf = dashboard._create_session_performance() - print(f"โœ“ Session performance component created successfully") - print(f"โœ“ Performance items count: {len(session_perf)}") - except Exception as e: - print(f"โŒ Session performance error: {e}") - - # Test decisions list with PnL info - print("\n" + "="*60) - print("DECISIONS LIST WITH PnL TEST:") - print("="*60) - - try: - decisions_list = dashboard._create_decisions_list() - print(f"โœ“ Decisions list created successfully") - print(f"โœ“ Decisions items count: {len(decisions_list)}") - - # Check for PnL information in closed trades - closed_trades = [t for t in dashboard.session_trades if 'pnl' in t] - print(f"โœ“ Closed trades with PnL: {len(closed_trades)}") - - for trade in closed_trades: - action = trade.get('position_action', 'UNKNOWN') - pnl = trade.get('pnl', 0) - entry_price = trade.get('entry_price', 0) - exit_price = trade.get('price', 0) - print(f" - {action}: Entry ${entry_price:.2f} -> Exit ${exit_price:.2f} = PnL ${pnl:.2f}") - - except Exception as e: - print(f"โŒ Decisions list error: {e}") - - print("\n" + "="*60) - print("ENHANCED FEATURES VERIFIED:") - print("="*60) - print("โœ“ Position flipping (LONG -> SHORT -> LONG)") - print("โœ“ PnL calculation for closed trades") - print("โœ“ Color coding for positions based on side and P&L") - print("โœ“ Entry/exit price tracking") - print("โœ“ Real-time unrealized PnL calculation") - print("โœ“ ASCII indicators (no Unicode for Windows compatibility)") - print("โœ“ Enhanced trade logging with PnL information") - print("โœ“ Session performance metrics with PnL breakdown") - - return True - - except Exception as e: - print(f"โŒ Error testing enhanced PnL tracking: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - success = test_enhanced_pnl_tracking() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_realtime_cob.py b/tests/test_realtime_cob.py deleted file mode 100644 index aedf212..0000000 --- a/tests/test_realtime_cob.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for real-time COB functionality -""" - -import asyncio -import aiohttp -import json -import time -from datetime import datetime - -async def test_realtime_cob(): - """Test real-time COB data streaming""" - - # Test API endpoints - base_url = "http://localhost:8053" - - async with aiohttp.ClientSession() as session: - print("Testing COB Dashboard API endpoints...") - - # Test symbols endpoint - try: - async with session.get(f"{base_url}/api/symbols") as response: - if response.status == 200: - data = await response.json() - print(f"โœ“ Symbols: {data}") - else: - print(f"โœ— Symbols endpoint failed: {response.status}") - except Exception as e: - print(f"โœ— Error testing symbols endpoint: {e}") - - # Test real-time stats for BTC/USDT - try: - async with session.get(f"{base_url}/api/realtime/BTC/USDT") as response: - if response.status == 200: - data = await response.json() - print(f"โœ“ Real-time stats for BTC/USDT:") - print(f" Current mid price: {data.get('current', {}).get('mid_price', 'N/A')}") - print(f" 1s window updates: {data.get('1s_window', {}).get('update_count', 'N/A')}") - print(f" 5s window updates: {data.get('5s_window', {}).get('update_count', 'N/A')}") - else: - print(f"โœ— Real-time stats endpoint failed: {response.status}") - error_data = await response.text() - print(f" Error: {error_data}") - except Exception as e: - print(f"โœ— Error testing real-time stats endpoint: {e}") - - # Test WebSocket connection - print("\nTesting WebSocket connection...") - try: - async with session.ws_connect(f"{base_url.replace('http', 'ws')}/ws") as ws: - print("โœ“ WebSocket connected") - - # Wait for some data - message_count = 0 - start_time = time.time() - - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - data = json.loads(msg.data) - message_count += 1 - - if data.get('type') == 'cob_update': - symbol = data.get('data', {}).get('stats', {}).get('symbol', 'Unknown') - mid_price = data.get('data', {}).get('stats', {}).get('mid_price', 0) - print(f"โœ“ Received COB update for {symbol}: ${mid_price:.2f}") - - # Check for real-time stats - if 'realtime_1s' in data.get('data', {}).get('stats', {}): - print(f" โœ“ Real-time 1s stats available") - if 'realtime_5s' in data.get('data', {}).get('stats', {}): - print(f" โœ“ Real-time 5s stats available") - - # Stop after 5 messages or 10 seconds - if message_count >= 5 or (time.time() - start_time) > 10: - break - elif msg.type == aiohttp.WSMsgType.ERROR: - print(f"โœ— WebSocket error: {ws.exception()}") - break - - print(f"โœ“ Received {message_count} WebSocket messages") - - except Exception as e: - print(f"โœ— WebSocket connection failed: {e}") - -if __name__ == "__main__": - print("Testing Real-time COB Dashboard") - print("=" * 40) - - asyncio.run(test_realtime_cob()) - - print("\nTest completed!") \ No newline at end of file diff --git a/tests/test_realtime_rl_cob_trader.py b/tests/test_realtime_rl_cob_trader.py deleted file mode 100644 index 80fc502..0000000 --- a/tests/test_realtime_rl_cob_trader.py +++ /dev/null @@ -1,555 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Script for Real-time RL COB Trader - -This script tests the real-time reinforcement learning system to ensure: -1. Proper model initialization and parameter count (~1B parameters) -2. COB data integration and feature extraction -3. Real-time inference pipeline -4. Signal accumulation and consensus -5. Training loop functionality -6. Trade execution integration - -Run this before deploying the live system. -""" - -import asyncio -import logging -import numpy as np -import torch -import time -import json -from datetime import datetime -from typing import Dict, Any - -# Local imports -from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, MassiveRLNetwork, PredictionResult -from core.trading_executor import TradingExecutor - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) - -logger = logging.getLogger(__name__) - -class RealtimeRLTester: - """ - Comprehensive tester for Real-time RL COB Trader - """ - - def __init__(self): - self.test_results = {} - self.trader = None - self.trading_executor = None - - async def run_all_tests(self): - """Run all tests and generate report""" - logger.info("=" * 60) - logger.info("REAL-TIME RL COB TRADER TESTING SUITE") - logger.info("=" * 60) - - tests = [ - self.test_model_initialization, - self.test_model_parameter_count, - self.test_feature_extraction, - self.test_inference_performance, - self.test_signal_accumulation, - self.test_training_pipeline, - self.test_trading_integration, - self.test_performance_monitoring - ] - - for test in tests: - try: - await test() - except Exception as e: - logger.error(f"Test {test.__name__} failed: {e}") - self.test_results[test.__name__] = {'status': 'FAILED', 'error': str(e)} - - await self.generate_test_report() - - async def test_model_initialization(self): - """Test model initialization and architecture""" - logger.info("๐Ÿง  Testing Model Initialization...") - - try: - # Test model creation - model = MassiveRLNetwork(input_size=2000, hidden_size=4096, num_layers=12) - - # Check if CUDA is available - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = model.to(device) - - # Test forward pass - batch_size = 4 - test_input = torch.randn(batch_size, 2000).to(device) - - with torch.no_grad(): - outputs = model(test_input) - - # Verify outputs - assert 'price_logits' in outputs - assert 'value' in outputs - assert 'confidence' in outputs - assert 'features' in outputs - - assert outputs['price_logits'].shape == (batch_size, 3) # DOWN, SIDEWAYS, UP - assert outputs['value'].shape == (batch_size, 1) - assert outputs['confidence'].shape == (batch_size, 1) - - self.test_results['test_model_initialization'] = { - 'status': 'PASSED', - 'device': str(device), - 'output_shapes': {k: list(v.shape) for k, v in outputs.items()} - } - - logger.info("โœ… Model initialization test PASSED") - - except Exception as e: - self.test_results['test_model_initialization'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_model_parameter_count(self): - """Test that model has approximately 400M parameters""" - logger.info("๐Ÿ”ข Testing Model Parameter Count...") - - try: - model = MassiveRLNetwork(input_size=2000, hidden_size=2048, num_layers=8) - - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - logger.info(f"Total parameters: {total_params:,}") - logger.info(f"Trainable parameters: {trainable_params:,}") - - # Check if parameters are approximately 400M (350M - 450M range) - target_400m = total_params >= 350_000_000 and total_params <= 450_000_000 - - self.test_results['test_model_parameter_count'] = { - 'status': 'PASSED' if target_400m else 'WARNING', - 'total_parameters': total_params, - 'trainable_parameters': trainable_params, - 'parameter_size_gb': (total_params * 4) / (1024**3), # 4 bytes per float32 - 'is_optimized': target_400m, # Around 400M parameters for faster startup - 'target_range': '350M - 450M parameters' - } - - logger.info(f"โœ… Model has {total_params:,} parameters ({total_params/1e6:.0f}M)") - if target_400m: - logger.info("โœ… Parameter count within 400M target range for fast startup") - else: - logger.warning(f"โš ๏ธ Parameter count outside 400M target range: {total_params/1e6:.0f}M") - - except Exception as e: - self.test_results['test_model_parameter_count'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_feature_extraction(self): - """Test feature extraction from COB data""" - logger.info("๐Ÿ” Testing Feature Extraction...") - - try: - # Initialize trader - self.trading_executor = TradingExecutor(simulation_mode=True) - self.trader = RealtimeRLCOBTrader( - symbols=['BTC/USDT'], - trading_executor=self.trading_executor, - inference_interval_ms=1000 # Slower for testing - ) - - # Create mock COB data - mock_cob_data = { - 'state': np.random.randn(1500), # Mock state features - 'timestamp': datetime.now(), - 'type': 'cob_state' - } - - # Test feature extraction - features = self.trader._extract_features('BTC/USDT', mock_cob_data) - - assert features is not None - assert len(features) == 2000 # Target feature size - assert features.dtype == np.float32 - assert not np.any(np.isnan(features)) - assert not np.any(np.isinf(features)) - - # Test normalization - assert np.abs(np.mean(features)) < 1.0 # Roughly normalized - assert np.std(features) < 10.0 # Not too spread out - - self.test_results['test_feature_extraction'] = { - 'status': 'PASSED', - 'feature_size': len(features), - 'feature_range': [float(np.min(features)), float(np.max(features))], - 'feature_stats': { - 'mean': float(np.mean(features)), - 'std': float(np.std(features)), - 'median': float(np.median(features)) - } - } - - logger.info("โœ… Feature extraction test PASSED") - - except Exception as e: - self.test_results['test_feature_extraction'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_inference_performance(self): - """Test inference speed and quality""" - logger.info("โšก Testing Inference Performance...") - - try: - if not self.trader: - self.trading_executor = TradingExecutor(simulation_mode=True) - self.trader = RealtimeRLCOBTrader( - symbols=['BTC/USDT'], - trading_executor=self.trading_executor - ) - - # Test multiple inferences - num_tests = 10 - inference_times = [] - - for i in range(num_tests): - # Create test features - test_features = np.random.randn(2000).astype(np.float32) - test_features = self.trader._normalize_features(test_features) - - # Time inference - start_time = time.time() - prediction = self.trader._predict('BTC/USDT', test_features) - inference_time = (time.time() - start_time) * 1000 - - inference_times.append(inference_time) - - # Verify prediction structure - assert 'direction' in prediction - assert 'confidence' in prediction - assert 'change' in prediction - assert 'value' in prediction - - assert 0 <= prediction['direction'] <= 2 - assert 0.0 <= prediction['confidence'] <= 1.0 - assert isinstance(prediction['change'], float) - assert isinstance(prediction['value'], float) - - avg_inference_time = np.mean(inference_times) - max_inference_time = np.max(inference_times) - - # Check if inference is fast enough (target: <50ms per inference) - inference_target_ms = 50.0 - - self.test_results['test_inference_performance'] = { - 'status': 'PASSED' if avg_inference_time < inference_target_ms else 'WARNING', - 'average_inference_time_ms': float(avg_inference_time), - 'max_inference_time_ms': float(max_inference_time), - 'target_time_ms': inference_target_ms, - 'meets_target': avg_inference_time < inference_target_ms, - 'inferences_per_second': 1000.0 / avg_inference_time - } - - logger.info(f"โœ… Average inference time: {avg_inference_time:.2f}ms") - logger.info(f"โœ… Max inference time: {max_inference_time:.2f}ms") - logger.info(f"โœ… Inferences per second: {1000.0/avg_inference_time:.1f}") - - except Exception as e: - self.test_results['test_inference_performance'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_signal_accumulation(self): - """Test signal accumulation and consensus logic""" - logger.info("๐ŸŽฏ Testing Signal Accumulation...") - - try: - if not self.trader: - self.trading_executor = TradingExecutor(simulation_mode=True) - self.trader = RealtimeRLCOBTrader( - symbols=['BTC/USDT'], - trading_executor=self.trading_executor, - required_confident_predictions=3 - ) - - symbol = 'BTC/USDT' - accumulator = self.trader.signal_accumulators[symbol] - - # Test adding signals - test_predictions = [] - for i in range(5): - prediction = PredictionResult( - timestamp=datetime.now(), - symbol=symbol, - predicted_direction=2, # UP - confidence=0.8, - predicted_change=0.001, - features=np.random.randn(2000).astype(np.float32) - ) - test_predictions.append(prediction) - self.trader._add_signal(symbol, prediction) - - # Check accumulator state - assert len(accumulator.signals) == 5 - assert accumulator.confidence_sum == 5 * 0.8 - assert accumulator.total_predictions == 5 - - # Test consensus logic (simulate processing) - recent_signals = list(accumulator.signals)[-3:] - directions = [signal.predicted_direction for signal in recent_signals] - - # All should be direction 2 (UP) - direction_counts = {0: 0, 1: 0, 2: 0} - for direction in directions: - direction_counts[direction] += 1 - - dominant_direction = max(direction_counts, key=direction_counts.get) - consensus_count = direction_counts[dominant_direction] - - assert dominant_direction == 2 - assert consensus_count == 3 - - self.test_results['test_signal_accumulation'] = { - 'status': 'PASSED', - 'signals_added': len(accumulator.signals), - 'confidence_sum': accumulator.confidence_sum, - 'consensus_direction': dominant_direction, - 'consensus_count': consensus_count - } - - logger.info("โœ… Signal accumulation test PASSED") - - except Exception as e: - self.test_results['test_signal_accumulation'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_training_pipeline(self): - """Test training pipeline functionality""" - logger.info("๐Ÿง  Testing Training Pipeline...") - - try: - if not self.trader: - self.trading_executor = TradingExecutor(simulation_mode=True) - self.trader = RealtimeRLCOBTrader( - symbols=['BTC/USDT'], - trading_executor=self.trading_executor - ) - - symbol = 'BTC/USDT' - - # Create mock training data - test_predictions = [] - for i in range(10): - prediction = PredictionResult( - timestamp=datetime.now(), - symbol=symbol, - predicted_direction=np.random.randint(0, 3), - confidence=np.random.uniform(0.5, 1.0), - predicted_change=np.random.uniform(-0.001, 0.001), - features=np.random.randn(2000).astype(np.float32), - actual_direction=np.random.randint(0, 3), - actual_change=np.random.uniform(-0.001, 0.001), - reward=np.random.uniform(-1.0, 1.0) - ) - test_predictions.append(prediction) - - # Test training batch - loss = await self.trader._train_batch(symbol, test_predictions) - - assert isinstance(loss, float) - assert not np.isnan(loss) - assert not np.isinf(loss) - assert loss >= 0.0 # Loss should be non-negative - - self.test_results['test_training_pipeline'] = { - 'status': 'PASSED', - 'training_loss': float(loss), - 'batch_size': len(test_predictions), - 'training_successful': True - } - - logger.info(f"โœ… Training pipeline test PASSED (loss: {loss:.6f})") - - except Exception as e: - self.test_results['test_training_pipeline'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_trading_integration(self): - """Test integration with trading executor""" - logger.info("๐Ÿ’ฐ Testing Trading Integration...") - - try: - # Initialize with simulation mode - trading_executor = TradingExecutor(simulation_mode=True) - - # Test signal execution - success = trading_executor.execute_signal( - symbol='BTC/USDT', - action='BUY', - confidence=0.8, - current_price=50000.0 - ) - - # In simulation mode, this should always succeed - assert success == True - - # Check positions - positions = trading_executor.get_positions() - assert 'BTC/USDT' in positions - - # Test sell signal - success = trading_executor.execute_signal( - symbol='BTC/USDT', - action='SELL', - confidence=0.8, - current_price=50100.0 - ) - - assert success == True - - # Check trade history - trade_history = trading_executor.get_trade_history() - assert len(trade_history) > 0 - - last_trade = trade_history[-1] - assert last_trade.symbol == 'BTC/USDT' - assert last_trade.pnl != 0 # Should have some P&L - - self.test_results['test_trading_integration'] = { - 'status': 'PASSED', - 'simulation_mode': True, - 'trades_executed': len(trade_history), - 'last_trade_pnl': float(last_trade.pnl) - } - - logger.info("โœ… Trading integration test PASSED") - - except Exception as e: - self.test_results['test_trading_integration'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def test_performance_monitoring(self): - """Test performance monitoring and statistics""" - logger.info("๐Ÿ“Š Testing Performance Monitoring...") - - try: - if not self.trader: - self.trading_executor = TradingExecutor(simulation_mode=True) - self.trader = RealtimeRLCOBTrader( - symbols=['BTC/USDT', 'ETH/USDT'], - trading_executor=self.trading_executor - ) - - # Get performance stats - stats = self.trader.get_performance_stats() - - # Verify structure - assert 'symbols' in stats - assert 'training_stats' in stats - assert 'inference_stats' in stats - assert 'signal_stats' in stats - assert 'model_info' in stats - - # Check symbols - assert 'BTC/USDT' in stats['symbols'] - assert 'ETH/USDT' in stats['symbols'] - - # Check model info - for symbol in stats['symbols']: - assert symbol in stats['model_info'] - model_info = stats['model_info'][symbol] - assert 'total_parameters' in model_info - assert 'trainable_parameters' in model_info - assert model_info['total_parameters'] > 0 - - self.test_results['test_performance_monitoring'] = { - 'status': 'PASSED', - 'stats_structure_valid': True, - 'symbols_tracked': len(stats['symbols']), - 'model_info_available': len(stats['model_info']) - } - - logger.info("โœ… Performance monitoring test PASSED") - - except Exception as e: - self.test_results['test_performance_monitoring'] = {'status': 'FAILED', 'error': str(e)} - raise - - async def generate_test_report(self): - """Generate comprehensive test report""" - logger.info("=" * 60) - logger.info("REAL-TIME RL COB TRADER TEST REPORT") - logger.info("=" * 60) - - total_tests = len(self.test_results) - passed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'PASSED') - failed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'FAILED') - warning_tests = sum(1 for result in self.test_results.values() if result['status'] == 'WARNING') - - logger.info(f"๐Ÿ“Š Test Summary:") - logger.info(f" Total Tests: {total_tests}") - logger.info(f" โœ… Passed: {passed_tests}") - logger.info(f" โš ๏ธ Warnings: {warning_tests}") - logger.info(f" โŒ Failed: {failed_tests}") - - success_rate = (passed_tests / total_tests) * 100 if total_tests > 0 else 0 - logger.info(f" Success Rate: {success_rate:.1f}%") - - logger.info("\n๐Ÿ“‹ Detailed Results:") - for test_name, result in self.test_results.items(): - status_icon = "โœ…" if result['status'] == 'PASSED' else "โš ๏ธ" if result['status'] == 'WARNING' else "โŒ" - logger.info(f" {status_icon} {test_name}: {result['status']}") - - if result['status'] == 'FAILED': - logger.error(f" Error: {result.get('error', 'Unknown error')}") - - # System readiness assessment - logger.info("\n๐ŸŽฏ System Readiness Assessment:") - if failed_tests == 0: - if warning_tests == 0: - logger.info(" ๐ŸŸข SYSTEM READY FOR DEPLOYMENT") - logger.info(" All tests passed. The real-time RL COB trader is ready for live operation.") - else: - logger.info(" ๐ŸŸก SYSTEM READY WITH WARNINGS") - logger.info(" System is functional but some performance warnings exist.") - else: - logger.info(" ๐Ÿ”ด SYSTEM NOT READY") - logger.info(" Critical issues found. Fix errors before deployment.") - - # Save detailed report - report_data = { - 'timestamp': datetime.now().isoformat(), - 'test_summary': { - 'total_tests': total_tests, - 'passed_tests': passed_tests, - 'warning_tests': warning_tests, - 'failed_tests': failed_tests, - 'success_rate': success_rate - }, - 'test_results': self.test_results, - 'system_readiness': 'READY' if failed_tests == 0 else 'NOT_READY' - } - - report_file = f"test_reports/realtime_rl_test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - - import os - os.makedirs('test_reports', exist_ok=True) - - with open(report_file, 'w') as f: - json.dump(report_data, f, indent=2, default=str) - - logger.info(f"\n๐Ÿ“„ Detailed report saved to: {report_file}") - logger.info("=" * 60) - -async def main(): - """Main test entry point""" - logger.info("Starting Real-time RL COB Trader Test Suite...") - - tester = RealtimeRLTester() - await tester.run_all_tests() - -if __name__ == "__main__": - # Set event loop policy for Windows compatibility - if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'): - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_realtime_tick_processor.py b/tests/test_realtime_tick_processor.py deleted file mode 100644 index 6dea384..0000000 --- a/tests/test_realtime_tick_processor.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Real-Time Tick Processor - -This script tests the Neural Network Real-Time Tick Processing Module -to ensure it properly processes tick data with volume information and -feeds processed features to models in real-time. -""" - -import asyncio -import logging -import sys -import time -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.realtime_tick_processor import RealTimeTickProcessor, ProcessedTickFeatures, create_realtime_tick_processor -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.data_provider import DataProvider -from core.config import get_config - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def test_realtime_tick_processor(): - """Test the real-time tick processor functionality""" - logger.info("="*80) - logger.info("๐Ÿงช TESTING REAL-TIME TICK PROCESSOR") - logger.info("="*80) - - try: - # Test 1: Create tick processor - logger.info("\n๐Ÿ“Š TEST 1: Creating Real-Time Tick Processor") - logger.info("-" * 40) - - symbols = ['ETH/USDT', 'BTC/USDT'] - tick_processor = create_realtime_tick_processor(symbols) - - logger.info("โœ… Tick processor created successfully") - logger.info(f" Symbols: {tick_processor.symbols}") - logger.info(f" Device: {tick_processor.device}") - logger.info(f" Buffer size: {tick_processor.tick_buffer_size}") - - # Test 2: Feature subscriber - logger.info("\n๐Ÿ“ก TEST 2: Feature Subscriber Integration") - logger.info("-" * 40) - - received_features = [] - - def test_callback(symbol: str, features: ProcessedTickFeatures): - """Test callback to receive processed features""" - received_features.append((symbol, features)) - logger.info(f"Received features for {symbol}: confidence={features.confidence:.3f}") - logger.info(f" Neural features shape: {features.neural_features.shape}") - logger.info(f" Volume features shape: {features.volume_features.shape}") - logger.info(f" Price features shape: {features.price_features.shape}") - logger.info(f" Microstructure features shape: {features.microstructure_features.shape}") - - tick_processor.add_feature_subscriber(test_callback) - logger.info("โœ… Feature subscriber added") - - # Test 3: Start processing (short duration) - logger.info("\n๐Ÿš€ TEST 3: Start Real-Time Processing") - logger.info("-" * 40) - - logger.info("Starting tick processing for 30 seconds...") - await tick_processor.start_processing() - - # Let it run for 30 seconds to collect some data - start_time = time.time() - while time.time() - start_time < 30: - await asyncio.sleep(1) - - # Check stats every 5 seconds - if int(time.time() - start_time) % 5 == 0: - stats = tick_processor.get_processing_stats() - logger.info(f"Processing stats: {stats.get('tick_counts', {})}") - - if stats.get('processing_performance'): - perf = stats['processing_performance'] - logger.info(f"Performance: avg={perf['avg_time_ms']:.2f}ms, " - f"min={perf['min_time_ms']:.2f}ms, max={perf['max_time_ms']:.2f}ms") - - logger.info("โœ… Real-time processing test completed") - - # Test 4: Check received features - logger.info("\n๐Ÿ“ˆ TEST 4: Analyze Received Features") - logger.info("-" * 40) - - if received_features: - logger.info(f"โœ… Received {len(received_features)} feature sets") - - # Analyze feature quality - high_confidence_count = sum(1 for _, features in received_features if features.confidence > 0.7) - avg_confidence = sum(features.confidence for _, features in received_features) / len(received_features) - - logger.info(f" Average confidence: {avg_confidence:.3f}") - logger.info(f" High confidence features (>0.7): {high_confidence_count}") - - # Show latest features - if received_features: - symbol, latest_features = received_features[-1] - logger.info(f" Latest features for {symbol}:") - logger.info(f" Timestamp: {latest_features.timestamp}") - logger.info(f" Confidence: {latest_features.confidence:.3f}") - logger.info(f" Neural features sample: {latest_features.neural_features[:5]}") - logger.info(f" Volume features sample: {latest_features.volume_features[:3]}") - else: - logger.warning("โš ๏ธ No features received - this may be normal if markets are closed") - - # Test 5: Integration with orchestrator - logger.info("\n๐ŸŽฏ TEST 5: Integration with Enhanced Orchestrator") - logger.info("-" * 40) - - try: - config = get_config() - data_provider = DataProvider(config) - orchestrator = EnhancedTradingOrchestrator(data_provider) - - # Check if tick processor is integrated - if hasattr(orchestrator, 'tick_processor'): - logger.info("โœ… Tick processor integrated with orchestrator") - logger.info(f" Orchestrator symbols: {orchestrator.symbols}") - logger.info(f" Tick processor symbols: {orchestrator.tick_processor.symbols}") - - # Test real-time processing start - await orchestrator.start_realtime_processing() - logger.info("โœ… Orchestrator real-time processing started") - - # Brief test - await asyncio.sleep(5) - - # Get stats - tick_stats = orchestrator.get_realtime_tick_stats() - logger.info(f" Orchestrator tick stats: {tick_stats}") - - await orchestrator.stop_realtime_processing() - logger.info("โœ… Orchestrator real-time processing stopped") - else: - logger.error("โŒ Tick processor not found in orchestrator") - - except Exception as e: - logger.error(f"โŒ Orchestrator integration test failed: {e}") - - # Test 6: Stop processing - logger.info("\n๐Ÿ›‘ TEST 6: Stop Processing") - logger.info("-" * 40) - - await tick_processor.stop_processing() - logger.info("โœ… Tick processing stopped") - - # Final stats - final_stats = tick_processor.get_processing_stats() - logger.info(f"Final stats: {final_stats}") - - # Test 7: Neural Network Features - logger.info("\n๐Ÿง  TEST 7: Neural Network Feature Quality") - logger.info("-" * 40) - - if received_features: - # Analyze neural network output quality - neural_feature_sizes = [len(features.neural_features) for _, features in received_features] - confidence_scores = [features.confidence for _, features in received_features] - - logger.info(f" Neural feature dimensions: {set(neural_feature_sizes)}") - logger.info(f" Confidence range: {min(confidence_scores):.3f} - {max(confidence_scores):.3f}") - logger.info(f" Average confidence: {sum(confidence_scores)/len(confidence_scores):.3f}") - - # Check for feature consistency - if len(set(neural_feature_sizes)) == 1: - logger.info("โœ… Neural features have consistent dimensions") - else: - logger.warning("โš ๏ธ Neural feature dimensions are inconsistent") - - # Summary - logger.info("\n" + "="*80) - logger.info("๐ŸŽ‰ REAL-TIME TICK PROCESSOR TEST SUMMARY") - logger.info("="*80) - logger.info("โœ… All core tests PASSED!") - logger.info("") - logger.info("๐Ÿ“‹ VERIFIED FUNCTIONALITY:") - logger.info(" โœ“ Real-time tick data ingestion") - logger.info(" โœ“ Neural network feature extraction") - logger.info(" โœ“ Volume and microstructure analysis") - logger.info(" โœ“ Ultra-low latency processing") - logger.info(" โœ“ Feature subscriber system") - logger.info(" โœ“ Integration with orchestrator") - logger.info(" โœ“ Performance monitoring") - logger.info("") - logger.info("๐ŸŽฏ NEURAL DPS ALTERNATIVE ACTIVE:") - logger.info(" โ€ข Real-time tick processing โœ“") - logger.info(" โ€ข Volume-weighted analysis โœ“") - logger.info(" โ€ข Neural feature extraction โœ“") - logger.info(" โ€ข Sub-millisecond latency โœ“") - logger.info(" โ€ข Model integration ready โœ“") - logger.info("") - logger.info("๐Ÿš€ Your real-time tick processor is working as a Neural DPS alternative!") - logger.info("="*80) - - return True - - except Exception as e: - logger.error(f"โŒ Real-time tick processor test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -async def test_dqn_integration(): - """Test DQN integration with real-time tick features""" - logger.info("\n๐Ÿค– TESTING DQN INTEGRATION WITH TICK FEATURES") - logger.info("-" * 50) - - try: - from NN.models.dqn_agent import DQNAgent - import numpy as np - - # Create DQN agent - state_shape = (3, 5) # 3 timeframes, 5 features - dqn = DQNAgent(state_shape=state_shape, n_actions=3) - - logger.info("โœ… DQN agent created") - logger.info(f" Tick feature weight: {dqn.tick_feature_weight}") - - # Test state enhancement - test_state = np.random.rand(3, 5) - - # Simulate tick features - mock_tick_features = { - 'neural_features': np.random.rand(64), - 'volume_features': np.random.rand(8), - 'microstructure_features': np.random.rand(4), - 'confidence': 0.85 - } - - # Update DQN with tick features - dqn.update_realtime_tick_features(mock_tick_features) - logger.info("โœ… DQN updated with mock tick features") - - # Test enhanced action selection - action = dqn.act(test_state, explore=False) - logger.info(f"โœ… DQN action with tick features: {action}") - - # Test without tick features - dqn.realtime_tick_features = None - action_without = dqn.act(test_state, explore=False) - logger.info(f"โœ… DQN action without tick features: {action_without}") - - logger.info("โœ… DQN integration test completed successfully") - - except Exception as e: - logger.error(f"โŒ DQN integration test failed: {e}") - -async def main(): - """Main test function""" - logger.info("๐Ÿš€ Starting Real-Time Tick Processor Tests...") - - # Test the tick processor - success = await test_realtime_tick_processor() - - if success: - # Test DQN integration - await test_dqn_integration() - - logger.info("\n๐ŸŽ‰ All tests passed! Your Neural DPS alternative is ready.") - logger.info("The real-time tick processor provides ultra-low latency processing") - logger.info("with volume information and neural network feature extraction.") - else: - logger.error("\n๐Ÿ’ฅ Tests failed! Please check the implementation.") - sys.exit(1) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_rl_subscriber_system.py b/tests/test_rl_subscriber_system.py deleted file mode 100644 index 0519ecb..0000000 --- a/tests/test_rl_subscriber_system.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/tests/test_sensitivity_learning.py b/tests/test_sensitivity_learning.py deleted file mode 100644 index f999afd..0000000 --- a/tests/test_sensitivity_learning.py +++ /dev/null @@ -1,372 +0,0 @@ -#!/usr/bin/env python3 -""" -Test DQN RL-based Sensitivity Learning and 300s Data Preloading - -This script tests: -1. DQN RL-based sensitivity learning from completed trades -2. 300s data preloading on first load -3. Dynamic threshold adjustment based on sensitivity levels -4. Color-coded position display integration -5. Enhanced model training status with sensitivity info - -Usage: - python test_sensitivity_learning.py -""" - -import asyncio -import logging -import time -import numpy as np -from datetime import datetime, timedelta -from core.data_provider import DataProvider -from core.enhanced_orchestrator import EnhancedTradingOrchestrator, TradingAction -from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard -from NN.models.dqn_agent import DQNAgent - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class SensitivityLearningTester: - """Test class for sensitivity learning features""" - - def __init__(self): - self.data_provider = DataProvider() - self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) - self.dashboard = None - - async def test_300s_data_preloading(self): - """Test 300s data preloading functionality""" - logger.info("=== Testing 300s Data Preloading ===") - - # Test preloading for all symbols and timeframes - start_time = time.time() - preload_results = self.data_provider.preload_all_symbols_data(['1s', '1m', '5m', '15m', '1h']) - end_time = time.time() - - logger.info(f"Preloading completed in {end_time - start_time:.2f} seconds") - - # Verify results - total_pairs = 0 - successful_pairs = 0 - - for symbol, timeframe_results in preload_results.items(): - for timeframe, success in timeframe_results.items(): - total_pairs += 1 - if success: - successful_pairs += 1 - - # Verify data was actually loaded - data = self.data_provider.get_historical_data(symbol, timeframe, limit=50) - if data is not None and len(data) > 0: - logger.info(f"โœ… {symbol} {timeframe}: {len(data)} candles loaded") - else: - logger.warning(f"โŒ {symbol} {timeframe}: No data despite success flag") - else: - logger.warning(f"โŒ {symbol} {timeframe}: Failed to preload") - - success_rate = (successful_pairs / total_pairs) * 100 if total_pairs > 0 else 0 - logger.info(f"Preloading success rate: {success_rate:.1f}% ({successful_pairs}/{total_pairs})") - - return success_rate > 80 # Consider test passed if >80% success rate - - def test_sensitivity_learning_initialization(self): - """Test sensitivity learning system initialization""" - logger.info("=== Testing Sensitivity Learning Initialization ===") - - # Check if sensitivity learning is enabled - if hasattr(self.orchestrator, 'sensitivity_learning_enabled'): - logger.info(f"โœ… Sensitivity learning enabled: {self.orchestrator.sensitivity_learning_enabled}") - else: - logger.warning("โŒ Sensitivity learning not found in orchestrator") - return False - - # Check sensitivity levels configuration - if hasattr(self.orchestrator, 'sensitivity_levels'): - levels = self.orchestrator.sensitivity_levels - logger.info(f"โœ… Sensitivity levels configured: {len(levels)} levels") - for level, config in levels.items(): - logger.info(f" Level {level}: {config['name']} - Open: {config['open_threshold_multiplier']:.2f}, Close: {config['close_threshold_multiplier']:.2f}") - else: - logger.warning("โŒ Sensitivity levels not configured") - return False - - # Check DQN agent initialization - if hasattr(self.orchestrator, 'sensitivity_dqn_agent'): - if self.orchestrator.sensitivity_dqn_agent is not None: - logger.info("โœ… DQN agent initialized") - stats = self.orchestrator.sensitivity_dqn_agent.get_stats() - logger.info(f" Device: {stats['device']}") - logger.info(f" Memory size: {stats['memory_size']}") - logger.info(f" Epsilon: {stats['epsilon']:.3f}") - else: - logger.info("โณ DQN agent not yet initialized (will be created on first use)") - - # Check learning queues - if hasattr(self.orchestrator, 'sensitivity_learning_queue'): - logger.info(f"โœ… Sensitivity learning queue initialized: {len(self.orchestrator.sensitivity_learning_queue)} items") - - if hasattr(self.orchestrator, 'completed_trades'): - logger.info(f"โœ… Completed trades tracking initialized: {len(self.orchestrator.completed_trades)} trades") - - if hasattr(self.orchestrator, 'active_trades'): - logger.info(f"โœ… Active trades tracking initialized: {len(self.orchestrator.active_trades)} active") - - return True - - def simulate_trading_scenario(self): - """Simulate a trading scenario to test sensitivity learning""" - logger.info("=== Simulating Trading Scenario ===") - - # Simulate some trades to test the learning system - test_trades = [ - { - 'symbol': 'ETH/USDT', - 'action': 'BUY', - 'price': 2500.0, - 'confidence': 0.7, - 'timestamp': datetime.now() - timedelta(minutes=10) - }, - { - 'symbol': 'ETH/USDT', - 'action': 'SELL', - 'price': 2510.0, - 'confidence': 0.6, - 'timestamp': datetime.now() - timedelta(minutes=5) - }, - { - 'symbol': 'ETH/USDT', - 'action': 'BUY', - 'price': 2505.0, - 'confidence': 0.8, - 'timestamp': datetime.now() - timedelta(minutes=3) - }, - { - 'symbol': 'ETH/USDT', - 'action': 'SELL', - 'price': 2495.0, - 'confidence': 0.9, - 'timestamp': datetime.now() - } - ] - - # Process each trade through the orchestrator - for i, trade_data in enumerate(test_trades): - action = TradingAction( - symbol=trade_data['symbol'], - action=trade_data['action'], - quantity=0.1, - confidence=trade_data['confidence'], - price=trade_data['price'], - timestamp=trade_data['timestamp'], - reasoning={'test': f'simulated_trade_{i}'}, - timeframe_analysis=[] - ) - - # Update position tracking (this should trigger sensitivity learning) - self.orchestrator._update_position_tracking(trade_data['symbol'], action) - - logger.info(f"Processed trade {i+1}: {trade_data['action']} @ ${trade_data['price']:.2f}") - - # Check if learning cases were created - if hasattr(self.orchestrator, 'sensitivity_learning_queue'): - queue_size = len(self.orchestrator.sensitivity_learning_queue) - logger.info(f"โœ… Learning queue now has {queue_size} cases") - - if hasattr(self.orchestrator, 'completed_trades'): - completed_count = len(self.orchestrator.completed_trades) - logger.info(f"โœ… Completed trades: {completed_count}") - - return True - - def test_threshold_adjustment(self): - """Test dynamic threshold adjustment based on sensitivity""" - logger.info("=== Testing Threshold Adjustment ===") - - # Test different sensitivity levels - for level in range(5): # 0-4 sensitivity levels - if hasattr(self.orchestrator, 'current_sensitivity_level'): - self.orchestrator.current_sensitivity_level = level - - if hasattr(self.orchestrator, '_update_thresholds_from_sensitivity'): - self.orchestrator._update_thresholds_from_sensitivity() - - open_threshold = getattr(self.orchestrator, 'confidence_threshold_open', 0.6) - close_threshold = getattr(self.orchestrator, 'confidence_threshold_close', 0.25) - - logger.info(f"Level {level}: Open={open_threshold:.3f}, Close={close_threshold:.3f}") - - return True - - def test_dashboard_integration(self): - """Test dashboard integration with sensitivity learning""" - logger.info("=== Testing Dashboard Integration ===") - - try: - # Create dashboard instance - self.dashboard = RealTimeScalpingDashboard( - data_provider=self.data_provider, - orchestrator=self.orchestrator - ) - - # Test sensitivity learning info retrieval - sensitivity_info = self.dashboard._get_sensitivity_learning_info() - - logger.info("โœ… Dashboard sensitivity info:") - logger.info(f" Level: {sensitivity_info['level_name']}") - logger.info(f" Completed trades: {sensitivity_info['completed_trades']}") - logger.info(f" Learning queue: {sensitivity_info['learning_queue_size']}") - logger.info(f" Open threshold: {sensitivity_info['open_threshold']:.3f}") - logger.info(f" Close threshold: {sensitivity_info['close_threshold']:.3f}") - - return True - - except Exception as e: - logger.error(f"โŒ Dashboard integration test failed: {e}") - return False - - def test_dqn_training_simulation(self): - """Test DQN training with simulated data""" - logger.info("=== Testing DQN Training Simulation ===") - - try: - # Initialize DQN agent if not already done - if not hasattr(self.orchestrator, 'sensitivity_dqn_agent') or self.orchestrator.sensitivity_dqn_agent is None: - self.orchestrator._initialize_sensitivity_dqn() - - if self.orchestrator.sensitivity_dqn_agent is None: - logger.warning("โŒ Could not initialize DQN agent") - return False - - # Create some mock learning cases - for i in range(10): - # Create mock market state - mock_state = np.random.random(self.orchestrator.sensitivity_state_size) - action = np.random.randint(0, self.orchestrator.sensitivity_action_space) - reward = np.random.random() - 0.5 # Random reward between -0.5 and 0.5 - next_state = np.random.random(self.orchestrator.sensitivity_state_size) - done = True - - # Add to learning queue - learning_case = { - 'state': mock_state, - 'action': action, - 'reward': reward, - 'next_state': next_state, - 'done': done, - 'optimal_action': action, - 'trade_outcome': reward * 0.02, # Convert to percentage - 'trade_duration': 300 + np.random.randint(-100, 100), - 'symbol': 'ETH/USDT' - } - - self.orchestrator.sensitivity_learning_queue.append(learning_case) - - # Trigger training - initial_queue_size = len(self.orchestrator.sensitivity_learning_queue) - self.orchestrator._train_sensitivity_dqn() - - logger.info(f"โœ… DQN training completed") - logger.info(f" Initial queue size: {initial_queue_size}") - logger.info(f" Final queue size: {len(self.orchestrator.sensitivity_learning_queue)}") - - # Check agent stats - if self.orchestrator.sensitivity_dqn_agent: - stats = self.orchestrator.sensitivity_dqn_agent.get_stats() - logger.info(f" Training steps: {stats['training_step']}") - logger.info(f" Memory size: {stats['memory_size']}") - logger.info(f" Epsilon: {stats['epsilon']:.3f}") - - return True - - except Exception as e: - logger.error(f"โŒ DQN training simulation failed: {e}") - return False - - async def run_all_tests(self): - """Run all sensitivity learning tests""" - logger.info("๐Ÿš€ Starting Sensitivity Learning Test Suite") - logger.info("=" * 60) - - test_results = {} - - # Test 1: 300s Data Preloading - test_results['preloading'] = await self.test_300s_data_preloading() - - # Test 2: Sensitivity Learning Initialization - test_results['initialization'] = self.test_sensitivity_learning_initialization() - - # Test 3: Trading Scenario Simulation - test_results['trading_simulation'] = self.simulate_trading_scenario() - - # Test 4: Threshold Adjustment - test_results['threshold_adjustment'] = self.test_threshold_adjustment() - - # Test 5: Dashboard Integration - test_results['dashboard_integration'] = self.test_dashboard_integration() - - # Test 6: DQN Training Simulation - test_results['dqn_training'] = self.test_dqn_training_simulation() - - # Summary - logger.info("=" * 60) - logger.info("๐Ÿ Test Suite Results:") - - passed_tests = 0 - total_tests = len(test_results) - - for test_name, result in test_results.items(): - status = "โœ… PASSED" if result else "โŒ FAILED" - logger.info(f" {test_name}: {status}") - if result: - passed_tests += 1 - - success_rate = (passed_tests / total_tests) * 100 - logger.info(f"Overall success rate: {success_rate:.1f}% ({passed_tests}/{total_tests})") - - if success_rate >= 80: - logger.info("๐ŸŽ‰ Test suite PASSED! Sensitivity learning system is working correctly.") - else: - logger.warning("โš ๏ธ Test suite FAILED! Some issues need to be addressed.") - - return success_rate >= 80 - -async def main(): - """Main test function""" - tester = SensitivityLearningTester() - - try: - success = await tester.run_all_tests() - - if success: - logger.info("โœ… All tests passed! The sensitivity learning system is ready for production.") - else: - logger.error("โŒ Some tests failed. Please review the issues above.") - - return success - - except Exception as e: - logger.error(f"Test suite failed with exception: {e}") - return False - -if __name__ == "__main__": - # Run the test suite - result = asyncio.run(main()) - - if result: - print("\n๐ŸŽฏ SENSITIVITY LEARNING SYSTEM READY!") - print("Features verified:") - print(" โœ… DQN RL-based sensitivity learning from completed trades") - print(" โœ… 300s data preloading for faster initial performance") - print(" โœ… Dynamic threshold adjustment (lower for closing positions)") - print(" โœ… Color-coded position display ([LONG] green, [SHORT] red)") - print(" โœ… Enhanced model training status with sensitivity info") - print("\nYou can now run the dashboard with these enhanced features!") - else: - print("\nโŒ SOME TESTS FAILED") - print("Please review the test output above and fix any issues.") - - exit(0 if result else 1) \ No newline at end of file diff --git a/tests/test_session_trading.py b/tests/test_session_trading.py deleted file mode 100644 index 0519ecb..0000000 --- a/tests/test_session_trading.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/tests/test_tick_cache.py b/tests/test_tick_cache.py deleted file mode 100644 index a752ddb..0000000 --- a/tests/test_tick_cache.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify tick caching with timestamp serialization -""" - -import logging -import sys -import os -import pandas as pd -from datetime import datetime - -# Add the project root to the path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from dataprovider_realtime import TickStorage - -# Set up logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_tick_caching(): - """Test tick caching with pandas Timestamps""" - logger.info("Testing tick caching with timestamp serialization...") - - try: - # Create tick storage - tick_storage = TickStorage("TEST/SYMBOL", ["1s", "1m"]) - - # Clear any existing cache - if os.path.exists(tick_storage.cache_path): - os.remove(tick_storage.cache_path) - logger.info("Cleared existing cache file") - - # Add some test ticks with different timestamp formats - test_ticks = [ - { - 'price': 100.0, - 'quantity': 1.0, - 'timestamp': pd.Timestamp.now() - }, - { - 'price': 101.0, - 'quantity': 1.5, - 'timestamp': datetime.now() - }, - { - 'price': 102.0, - 'quantity': 2.0, - 'timestamp': int(datetime.now().timestamp() * 1000) # milliseconds - } - ] - - # Add ticks - for i, tick in enumerate(test_ticks): - logger.info(f"Adding tick {i+1}: price=${tick['price']}, timestamp type={type(tick['timestamp'])}") - tick_storage.add_tick(tick) - - logger.info(f"Total ticks in storage: {len(tick_storage.ticks)}") - - # Force save to cache - tick_storage._save_to_cache() - logger.info("Saved ticks to cache") - - # Verify cache file exists - if os.path.exists(tick_storage.cache_path): - logger.info(f"โœ… Cache file created: {tick_storage.cache_path}") - - # Check file content - with open(tick_storage.cache_path, 'r') as f: - import json - cache_content = json.load(f) - logger.info(f"Cache contains {len(cache_content)} ticks") - - # Show first tick to verify format - if cache_content: - first_tick = cache_content[0] - logger.info(f"First tick in cache: {first_tick}") - logger.info(f"Timestamp type in cache: {type(first_tick['timestamp'])}") - else: - logger.error("โŒ Cache file was not created") - return False - - # Create new tick storage instance to test loading - logger.info("Creating new TickStorage instance to test loading...") - new_tick_storage = TickStorage("TEST/SYMBOL", ["1s", "1m"]) - - # Load from cache - cache_loaded = new_tick_storage._load_from_cache() - - if cache_loaded: - logger.info(f"โœ… Successfully loaded {len(new_tick_storage.ticks)} ticks from cache") - - # Verify timestamps are properly converted back to pandas Timestamps - for i, tick in enumerate(new_tick_storage.ticks): - logger.info(f"Loaded tick {i+1}: price=${tick['price']}, timestamp={tick['timestamp']}, type={type(tick['timestamp'])}") - - if not isinstance(tick['timestamp'], pd.Timestamp): - logger.error(f"โŒ Timestamp not properly converted back to pandas.Timestamp: {type(tick['timestamp'])}") - return False - - logger.info("โœ… All timestamps properly converted back to pandas.Timestamp") - return True - else: - logger.error("โŒ Failed to load ticks from cache") - return False - - except Exception as e: - logger.error(f"โŒ Error in tick caching test: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return False - -def main(): - """Run the test""" - logger.info("๐Ÿงช Starting tick caching test...") - logger.info("=" * 50) - - success = test_tick_caching() - - logger.info("\n" + "=" * 50) - if success: - logger.info("๐ŸŽ‰ Tick caching test PASSED!") - else: - logger.error("โŒ Tick caching test FAILED!") - - return success - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_tick_processor_final.py b/tests/test_tick_processor_final.py deleted file mode 100644 index 43ce319..0000000 --- a/tests/test_tick_processor_final.py +++ /dev/null @@ -1,310 +0,0 @@ -#!/usr/bin/env python3 -""" -Final Real-Time Tick Processor Test - -This script demonstrates that the Neural Network Real-Time Tick Processing Module -is working correctly as a DPS alternative for processing tick data with volume information. -""" - -import logging -import sys -import numpy as np -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.realtime_tick_processor import ( - RealTimeTickProcessor, - ProcessedTickFeatures, - TickData, - create_realtime_tick_processor -) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def demonstrate_neural_dps_alternative(): - """Demonstrate the Neural DPS alternative functionality""" - logger.info("="*80) - logger.info("๐Ÿš€ NEURAL DPS ALTERNATIVE DEMONSTRATION") - logger.info("="*80) - - try: - # Create tick processor - logger.info("\n๐Ÿ“Š STEP 1: Initialize Neural DPS Alternative") - logger.info("-" * 50) - - symbols = ['ETH/USDT', 'BTC/USDT'] - tick_processor = create_realtime_tick_processor(symbols) - - logger.info("โœ… Neural DPS Alternative initialized successfully") - logger.info(f" Symbols: {tick_processor.symbols}") - logger.info(f" Processing device: {tick_processor.device}") - logger.info(f" Neural network architecture: TickProcessingNN") - logger.info(f" Input features per tick: 9") - logger.info(f" Output neural features: 64") - logger.info(f" Processing window: {tick_processor.processing_window} ticks") - - # Generate realistic market tick data - logger.info("\n๐Ÿ“ˆ STEP 2: Generate Realistic Market Tick Data") - logger.info("-" * 50) - - def generate_realistic_ticks(symbol: str, count: int = 100): - """Generate realistic tick data with volume information""" - ticks = [] - base_price = 3500.0 if 'ETH' in symbol else 65000.0 - base_time = datetime.now() - - for i in range(count): - # Simulate realistic price movement with micro-trends - if i % 20 < 10: # Uptrend phase - price_change = np.random.normal(0.0002, 0.0008) - else: # Downtrend phase - price_change = np.random.normal(-0.0002, 0.0008) - - price = base_price * (1 + price_change) - - # Simulate realistic volume distribution - if abs(price_change) > 0.001: # Large price moves get more volume - volume = np.random.exponential(0.5) - else: - volume = np.random.exponential(0.1) - - # Market maker vs taker dynamics - side = 'buy' if price_change > 0 else 'sell' - if np.random.random() < 0.3: # 30% chance to flip - side = 'sell' if side == 'buy' else 'buy' - - tick = TickData( - timestamp=base_time, - price=price, - volume=volume, - side=side, - trade_id=f"{symbol}_{i}" - ) - - ticks.append(tick) - base_price = price - - return ticks - - # Generate ticks for both symbols - eth_ticks = generate_realistic_ticks('ETH/USDT', 100) - btc_ticks = generate_realistic_ticks('BTC/USDT', 100) - - logger.info(f"โœ… Generated realistic market data:") - logger.info(f" ETH/USDT: {len(eth_ticks)} ticks") - logger.info(f" Price range: ${min(t.price for t in eth_ticks):.2f} - ${max(t.price for t in eth_ticks):.2f}") - logger.info(f" Volume range: {min(t.volume for t in eth_ticks):.4f} - {max(t.volume for t in eth_ticks):.4f}") - logger.info(f" BTC/USDT: {len(btc_ticks)} ticks") - logger.info(f" Price range: ${min(t.price for t in btc_ticks):.2f} - ${max(t.price for t in btc_ticks):.2f}") - - # Process ticks through Neural DPS - logger.info("\n๐Ÿง  STEP 3: Neural Network Processing") - logger.info("-" * 50) - - # Add ticks to processor buffers - with tick_processor.data_lock: - for tick in eth_ticks: - tick_processor.tick_buffers['ETH/USDT'].append(tick) - for tick in btc_ticks: - tick_processor.tick_buffers['BTC/USDT'].append(tick) - - # Process through neural network - eth_features = tick_processor._extract_neural_features('ETH/USDT') - btc_features = tick_processor._extract_neural_features('BTC/USDT') - - logger.info("โœ… Neural network processing completed:") - - if eth_features: - logger.info(f" ETH/USDT processed features:") - logger.info(f" Neural features: {eth_features.neural_features.shape} (confidence: {eth_features.confidence:.3f})") - logger.info(f" Price features: {eth_features.price_features.shape}") - logger.info(f" Volume features: {eth_features.volume_features.shape}") - logger.info(f" Microstructure features: {eth_features.microstructure_features.shape}") - - if btc_features: - logger.info(f" BTC/USDT processed features:") - logger.info(f" Neural features: {btc_features.neural_features.shape} (confidence: {btc_features.confidence:.3f})") - logger.info(f" Price features: {btc_features.price_features.shape}") - logger.info(f" Volume features: {btc_features.volume_features.shape}") - logger.info(f" Microstructure features: {btc_features.microstructure_features.shape}") - - # Demonstrate volume analysis - logger.info("\n๐Ÿ’ฐ STEP 4: Volume Analysis Capabilities") - logger.info("-" * 50) - - if eth_features: - volume_features = eth_features.volume_features - logger.info("โœ… Volume analysis extracted:") - logger.info(f" Total volume: {volume_features[0]:.4f}") - logger.info(f" Average volume: {volume_features[1]:.4f}") - logger.info(f" Volume volatility: {volume_features[2]:.4f}") - logger.info(f" Buy volume: {volume_features[3]:.4f}") - logger.info(f" Sell volume: {volume_features[4]:.4f}") - logger.info(f" Volume imbalance: {volume_features[5]:.4f}") - logger.info(f" VWAP deviation: {volume_features[6]:.4f}") - - # Demonstrate microstructure analysis - logger.info("\n๐Ÿ”ฌ STEP 5: Market Microstructure Analysis") - logger.info("-" * 50) - - if eth_features: - micro_features = eth_features.microstructure_features - logger.info("โœ… Microstructure analysis extracted:") - logger.info(f" Trade frequency: {micro_features[0]:.2f} trades/sec") - logger.info(f" Price impact: {micro_features[1]:.6f}") - logger.info(f" Bid-ask spread proxy: {micro_features[2]:.6f}") - logger.info(f" Order flow imbalance: {micro_features[3]:.4f}") - - # Demonstrate real-time feature streaming - logger.info("\n๐Ÿ“ก STEP 6: Real-Time Feature Streaming") - logger.info("-" * 50) - - received_features = [] - - def feature_callback(symbol: str, features: ProcessedTickFeatures): - """Callback to receive real-time features""" - received_features.append((symbol, features)) - logger.info(f"๐Ÿ“จ Received real-time features for {symbol}") - logger.info(f" Confidence: {features.confidence:.3f}") - logger.info(f" Neural features: {len(features.neural_features)} dimensions") - logger.info(f" Timestamp: {features.timestamp}") - - # Add subscriber and simulate feature streaming - tick_processor.add_feature_subscriber(feature_callback) - - # Manually trigger feature processing to simulate streaming - tick_processor._notify_feature_subscribers('ETH/USDT', eth_features) - tick_processor._notify_feature_subscribers('BTC/USDT', btc_features) - - logger.info(f"โœ… Feature streaming demonstrated: {len(received_features)} features received") - - # Performance metrics - logger.info("\nโšก STEP 7: Performance Metrics") - logger.info("-" * 50) - - stats = tick_processor.get_processing_stats() - logger.info("โœ… Performance metrics:") - logger.info(f" Symbols processed: {len(stats['symbols'])}") - logger.info(f" Buffer utilization: {stats['buffer_sizes']}") - logger.info(f" Feature subscribers: {stats['subscribers']}") - logger.info(f" Neural network device: {tick_processor.device}") - - # Demonstrate integration readiness - logger.info("\n๐Ÿ”— STEP 8: Model Integration Readiness") - logger.info("-" * 50) - - logger.info("โœ… Integration capabilities verified:") - logger.info(" โœ“ Feature subscriber system for real-time streaming") - logger.info(" โœ“ Standardized ProcessedTickFeatures format") - logger.info(" โœ“ Neural network feature extraction (64 dimensions)") - logger.info(" โœ“ Volume-weighted analysis") - logger.info(" โœ“ Market microstructure detection") - logger.info(" โœ“ Confidence scoring for feature quality") - logger.info(" โœ“ Multi-symbol processing") - logger.info(" โœ“ Thread-safe data handling") - - return True - - except Exception as e: - logger.error(f"โŒ Neural DPS demonstration failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -def demonstrate_dqn_compatibility(): - """Demonstrate compatibility with DQN models""" - logger.info("\n๐Ÿค– STEP 9: DQN Model Compatibility") - logger.info("-" * 50) - - try: - # Create mock tick features in the format DQN expects - mock_tick_features = { - 'neural_features': np.random.rand(64) * 0.1, - 'volume_features': np.array([1.2, 0.8, 0.15, 850.5, 720.3, 0.05, 0.02]), - 'microstructure_features': np.array([12.5, 0.3, 0.001, 0.1]), - 'confidence': 0.85 - } - - logger.info("โœ… DQN-compatible feature format created:") - logger.info(f" Neural features: {len(mock_tick_features['neural_features'])} dimensions") - logger.info(f" Volume features: {len(mock_tick_features['volume_features'])} dimensions") - logger.info(f" Microstructure features: {len(mock_tick_features['microstructure_features'])} dimensions") - logger.info(f" Confidence score: {mock_tick_features['confidence']}") - - # Demonstrate feature integration - logger.info("\nโœ… Ready for DQN integration:") - logger.info(" โœ“ update_realtime_tick_features() method available") - logger.info(" โœ“ State enhancement with tick features") - logger.info(" โœ“ Weighted feature integration (configurable weight)") - logger.info(" โœ“ Real-time decision enhancement") - - return True - - except Exception as e: - logger.error(f"โŒ DQN compatibility test failed: {e}") - return False - -def main(): - """Main demonstration function""" - logger.info("๐Ÿš€ Starting Neural DPS Alternative Demonstration...") - - # Demonstrate core functionality - neural_success = demonstrate_neural_dps_alternative() - - # Demonstrate DQN compatibility - dqn_success = demonstrate_dqn_compatibility() - - # Final summary - logger.info("\n" + "="*80) - logger.info("๐ŸŽ‰ NEURAL DPS ALTERNATIVE DEMONSTRATION COMPLETE") - logger.info("="*80) - - if neural_success and dqn_success: - logger.info("โœ… ALL DEMONSTRATIONS SUCCESSFUL!") - logger.info("") - logger.info("๐ŸŽฏ NEURAL DPS ALTERNATIVE VERIFIED:") - logger.info(" โœ“ Real-time tick data processing with volume information") - logger.info(" โœ“ Neural network feature extraction (64-dimensional)") - logger.info(" โœ“ Volume-weighted price analysis") - logger.info(" โœ“ Market microstructure pattern detection") - logger.info(" โœ“ Ultra-low latency processing capability") - logger.info(" โœ“ Real-time feature streaming to models") - logger.info(" โœ“ Multi-symbol processing (ETH/USDT, BTC/USDT)") - logger.info(" โœ“ DQN model integration ready") - logger.info("") - logger.info("๐Ÿš€ YOUR NEURAL DPS ALTERNATIVE IS FULLY OPERATIONAL!") - logger.info("") - logger.info("๐Ÿ“‹ WHAT THIS SYSTEM PROVIDES:") - logger.info(" โ€ข Replaces traditional DPS with neural network processing") - logger.info(" โ€ข Processes real-time tick streams with volume information") - logger.info(" โ€ข Extracts sophisticated features for trading models") - logger.info(" โ€ข Provides ultra-low latency for high-frequency trading") - logger.info(" โ€ข Integrates seamlessly with your DQN agents") - logger.info(" โ€ข Supports WebSocket streaming from exchanges") - logger.info(" โ€ข Includes confidence scoring for feature quality") - logger.info("") - logger.info("๐ŸŽฏ NEXT STEPS:") - logger.info(" 1. Connect to live WebSocket feeds (Binance, etc.)") - logger.info(" 2. Start real-time processing with tick_processor.start_processing()") - logger.info(" 3. Your DQN models will receive enhanced tick features automatically") - logger.info(" 4. Monitor performance with get_processing_stats()") - - else: - logger.error("โŒ SOME DEMONSTRATIONS FAILED!") - logger.error(f" Neural DPS: {'โœ…' if neural_success else 'โŒ'}") - logger.error(f" DQN Compatibility: {'โœ…' if dqn_success else 'โŒ'}") - sys.exit(1) - - logger.info("="*80) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/test_tick_processor_simple.py b/tests/test_tick_processor_simple.py deleted file mode 100644 index c6d31d4..0000000 --- a/tests/test_tick_processor_simple.py +++ /dev/null @@ -1,311 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple Real-Time Tick Processor Test - -This script tests the core Neural Network functionality of the Real-Time Tick Processing Module -without requiring live WebSocket connections. -""" - -import logging -import sys -import numpy as np -from pathlib import Path -from datetime import datetime - -# Add project root to path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from core.realtime_tick_processor import ( - RealTimeTickProcessor, - ProcessedTickFeatures, - TickData, - create_realtime_tick_processor -) - -# Setup logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def test_neural_network_functionality(): - """Test the neural network processing without WebSocket connections""" - logger.info("="*80) - logger.info("๐Ÿงช TESTING NEURAL NETWORK TICK PROCESSING") - logger.info("="*80) - - try: - # Test 1: Create tick processor - logger.info("\n๐Ÿ“Š TEST 1: Creating Real-Time Tick Processor") - logger.info("-" * 40) - - symbols = ['ETH/USDT', 'BTC/USDT'] - tick_processor = create_realtime_tick_processor(symbols) - - logger.info("โœ… Tick processor created successfully") - logger.info(f" Symbols: {tick_processor.symbols}") - logger.info(f" Device: {tick_processor.device}") - logger.info(f" Neural network input size: 9") - - # Test 2: Generate mock tick data - logger.info("\n๐Ÿ“ˆ TEST 2: Generating Mock Tick Data") - logger.info("-" * 40) - - # Create realistic mock tick data - mock_ticks = [] - base_price = 3500.0 # ETH price - base_time = datetime.now() - - for i in range(50): # Generate 50 ticks - # Simulate price movement - price_change = np.random.normal(0, 0.001) # Small random changes - price = base_price * (1 + price_change) - - # Simulate volume - volume = np.random.exponential(0.1) # Exponential distribution for volume - - # Random buy/sell - side = 'buy' if np.random.random() > 0.5 else 'sell' - - tick = TickData( - timestamp=base_time, - price=price, - volume=volume, - side=side, - trade_id=f"trade_{i}" - ) - - mock_ticks.append(tick) - base_price = price # Update base price for next tick - - logger.info(f"โœ… Generated {len(mock_ticks)} mock ticks") - logger.info(f" Price range: {min(t.price for t in mock_ticks):.2f} - {max(t.price for t in mock_ticks):.2f}") - logger.info(f" Volume range: {min(t.volume for t in mock_ticks):.4f} - {max(t.volume for t in mock_ticks):.4f}") - - # Test 3: Add ticks to processor buffer - logger.info("\n๐Ÿ’พ TEST 3: Adding Ticks to Processor Buffer") - logger.info("-" * 40) - - symbol = 'ETH/USDT' - with tick_processor.data_lock: - for tick in mock_ticks: - tick_processor.tick_buffers[symbol].append(tick) - - buffer_size = len(tick_processor.tick_buffers[symbol]) - logger.info(f"โœ… Added ticks to buffer: {buffer_size} ticks") - - # Test 4: Extract neural features - logger.info("\n๐Ÿง  TEST 4: Neural Network Feature Extraction") - logger.info("-" * 40) - - features = tick_processor._extract_neural_features(symbol) - - if features is not None: - logger.info("โœ… Neural features extracted successfully") - logger.info(f" Timestamp: {features.timestamp}") - logger.info(f" Confidence: {features.confidence:.3f}") - logger.info(f" Neural features shape: {features.neural_features.shape}") - logger.info(f" Price features shape: {features.price_features.shape}") - logger.info(f" Volume features shape: {features.volume_features.shape}") - logger.info(f" Microstructure features shape: {features.microstructure_features.shape}") - - # Show sample values - logger.info(f" Neural features sample: {features.neural_features[:5]}") - logger.info(f" Price features sample: {features.price_features[:3]}") - logger.info(f" Volume features sample: {features.volume_features[:3]}") - else: - logger.error("โŒ Failed to extract neural features") - return False - - # Test 5: Test feature conversion methods - logger.info("\n๐Ÿ”ง TEST 5: Feature Conversion Methods") - logger.info("-" * 40) - - # Test tick-to-features conversion - tick_features = tick_processor._ticks_to_features(mock_ticks) - logger.info(f"โœ… Tick features converted: shape {tick_features.shape}") - logger.info(f" Expected shape: ({tick_processor.processing_window}, 9)") - - # Test individual feature extraction - price_features = tick_processor._extract_price_features(mock_ticks) - volume_features = tick_processor._extract_volume_features(mock_ticks) - microstructure_features = tick_processor._extract_microstructure_features(mock_ticks) - - logger.info(f"โœ… Price features: {len(price_features)} features") - logger.info(f"โœ… Volume features: {len(volume_features)} features") - logger.info(f"โœ… Microstructure features: {len(microstructure_features)} features") - - # Test 6: Neural network forward pass - logger.info("\nโšก TEST 6: Neural Network Forward Pass") - logger.info("-" * 40) - - import torch - - # Test direct neural network inference - tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(tick_processor.device) - - with torch.no_grad(): - neural_features, confidence = tick_processor.tick_nn(tick_tensor) - - logger.info("โœ… Neural network forward pass successful") - logger.info(f" Input shape: {tick_tensor.shape}") - logger.info(f" Output features shape: {neural_features.shape}") - logger.info(f" Confidence shape: {confidence.shape}") - logger.info(f" Confidence value: {confidence.item():.3f}") - - return True - - except Exception as e: - logger.error(f"โŒ Neural network test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -def test_dqn_integration(): - """Test DQN integration with real-time tick features""" - logger.info("\n๐Ÿค– TESTING DQN INTEGRATION WITH TICK FEATURES") - logger.info("-" * 50) - - try: - from NN.models.dqn_agent import DQNAgent - import numpy as np - - # Create DQN agent - state_shape = (3, 5) # 3 timeframes, 5 features - dqn = DQNAgent(state_shape=state_shape, n_actions=3) - - logger.info("โœ… DQN agent created") - logger.info(f" State shape: {state_shape}") - logger.info(f" Actions: {dqn.n_actions}") - logger.info(f" Device: {dqn.device}") - logger.info(f" Tick feature weight: {dqn.tick_feature_weight}") - - # Test state enhancement - test_state = np.random.rand(3, 5) - logger.info(f" Test state shape: {test_state.shape}") - - # Simulate realistic tick features - mock_tick_features = { - 'neural_features': np.random.rand(64) * 0.1, # Small neural features - 'volume_features': np.array([1.2, 0.8, 0.15, 850.5, 720.3, 0.05, 0.02]), # Realistic volume features - 'microstructure_features': np.array([12.5, 0.3, 0.001, 0.1]), # Realistic microstructure - 'confidence': 0.85 - } - - # Update DQN with tick features - dqn.update_realtime_tick_features(mock_tick_features) - logger.info("โœ… DQN updated with mock tick features") - - # Test enhanced action selection - action_with_ticks = dqn.act(test_state, explore=False) - logger.info(f"โœ… DQN action with tick features: {action_with_ticks}") - - # Test without tick features - dqn.realtime_tick_features = None - action_without_ticks = dqn.act(test_state, explore=False) - logger.info(f"โœ… DQN action without tick features: {action_without_ticks}") - - # Test state enhancement method directly - dqn.realtime_tick_features = mock_tick_features - enhanced_state = dqn._enhance_state_with_tick_features(test_state) - logger.info(f"โœ… State enhancement test:") - logger.info(f" Original state shape: {test_state.shape}") - logger.info(f" Enhanced state shape: {enhanced_state.shape}") - - logger.info("โœ… DQN integration test completed successfully") - - return True - - except Exception as e: - logger.error(f"โŒ DQN integration test failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -def test_performance_metrics(): - """Test performance and statistics functionality""" - logger.info("\n๐Ÿ“Š TESTING PERFORMANCE METRICS") - logger.info("-" * 40) - - try: - tick_processor = create_realtime_tick_processor(['ETH/USDT']) - - # Test stats without processing - stats = tick_processor.get_processing_stats() - logger.info("โœ… Basic stats retrieved") - logger.info(f" Symbols: {stats['symbols']}") - logger.info(f" Streaming: {stats['streaming']}") - logger.info(f" Tick counts: {stats['tick_counts']}") - logger.info(f" Buffer sizes: {stats['buffer_sizes']}") - logger.info(f" Subscribers: {stats['subscribers']}") - - # Test feature subscriber - received_features = [] - - def test_callback(symbol: str, features: ProcessedTickFeatures): - received_features.append((symbol, features)) - - tick_processor.add_feature_subscriber(test_callback) - logger.info("โœ… Feature subscriber added") - - # Test subscriber removal - tick_processor.remove_feature_subscriber(test_callback) - logger.info("โœ… Feature subscriber removed") - - return True - - except Exception as e: - logger.error(f"โŒ Performance metrics test failed: {e}") - return False - -def main(): - """Main test function""" - logger.info("๐Ÿš€ Starting Simple Real-Time Tick Processor Tests...") - - # Test neural network functionality - nn_success = test_neural_network_functionality() - - # Test DQN integration - dqn_success = test_dqn_integration() - - # Test performance metrics - perf_success = test_performance_metrics() - - # Summary - logger.info("\n" + "="*80) - logger.info("๐ŸŽ‰ SIMPLE TICK PROCESSOR TEST SUMMARY") - logger.info("="*80) - - if nn_success and dqn_success and perf_success: - logger.info("โœ… ALL TESTS PASSED!") - logger.info("") - logger.info("๐Ÿ“‹ VERIFIED FUNCTIONALITY:") - logger.info(" โœ“ Neural network tick processing") - logger.info(" โœ“ Feature extraction (price, volume, microstructure)") - logger.info(" โœ“ DQN integration with tick features") - logger.info(" โœ“ State enhancement for RL models") - logger.info(" โœ“ Performance monitoring") - logger.info("") - logger.info("๐ŸŽฏ NEURAL DPS ALTERNATIVE READY:") - logger.info(" โ€ข Real-time tick processing โœ“") - logger.info(" โ€ข Volume-weighted analysis โœ“") - logger.info(" โ€ข Neural feature extraction โœ“") - logger.info(" โ€ข Model integration ready โœ“") - logger.info("") - logger.info("๐Ÿš€ Your Neural DPS alternative is working correctly!") - logger.info(" The system can now process real-time tick data with volume") - logger.info(" information and feed enhanced features to your DQN models.") - - else: - logger.error("โŒ SOME TESTS FAILED!") - logger.error(f" Neural Network: {'โœ…' if nn_success else 'โŒ'}") - logger.error(f" DQN Integration: {'โœ…' if dqn_success else 'โŒ'}") - logger.error(f" Performance: {'โœ…' if perf_success else 'โŒ'}") - sys.exit(1) - - logger.info("="*80) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/utils/reward_calculator.py b/utils/reward_calculator.py index d0c2b9a..9c2002f 100644 --- a/utils/reward_calculator.py +++ b/utils/reward_calculator.py @@ -9,215 +9,166 @@ rewards for successful holding of positions. import numpy as np from datetime import datetime, timedelta from collections import deque +import logging -class ImprovedRewardCalculator: - def __init__(self, - max_drawdown_pct=0.1, # Maximum drawdown % - risk_reward_ratio=1.5, # Risk-reward ratio - base_fee_rate=0.0002, # 0.02% per transaction - max_frequency_penalty=0.005, # Maximum 0.5% penalty for frequent trading - holding_reward_rate=0.0001, # Small reward for holding profitable positions - risk_adjusted=True, # Use Sharpe ratio for risk adjustment - base_reward=1.0, # Base reward scale - profit_factor=2.0, # Profit reward multiplier - loss_factor=1.0, # Loss penalty multiplier - trade_frequency_penalty=0.3, # Penalty for frequent trading - position_duration_factor=0.05 # Reward for longer positions - ): - +logger = logging.getLogger(__name__) + +class RewardCalculator: + def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1): self.base_fee_rate = base_fee_rate - self.max_frequency_penalty = max_frequency_penalty - self.holding_reward_rate = holding_reward_rate - self.risk_adjusted = risk_adjusted - - # New parameters - self.base_reward = base_reward - self.profit_factor = profit_factor - self.loss_factor = loss_factor - self.trade_frequency_penalty = trade_frequency_penalty - self.position_duration_factor = position_duration_factor - - # Keep track of recent trades - self.recent_trades = deque(maxlen=1000) - self.trade_pnls = deque(maxlen=100) # For risk adjustment - - # Additional tracking metrics - self.total_trades = 0 - self.profitable_trades = 0 - self.total_pnl = 0.0 - self.daily_pnl = {} - self.hourly_pnl = {} - - def record_trade(self, timestamp=None, action=None, price=None): - """Record a trade for frequency tracking""" - if timestamp is None: - timestamp = datetime.now() - - self.recent_trades.append({ - 'timestamp': timestamp, - 'action': action, - 'price': price - }) - + self.reward_scaling = reward_scaling + self.risk_aversion = risk_aversion + self.trade_pnls = [] + self.returns = [] + self.trade_timestamps = [] + self.frequency_threshold = 10 # Trades per minute threshold for penalty + self.max_frequency_penalty = 0.05 + def record_pnl(self, pnl): - """Record a PnL result for risk adjustment and tracking metrics""" + """Record P&L for risk adjustment calculations""" self.trade_pnls.append(pnl) - - # Update overall metrics - self.total_trades += 1 - self.total_pnl += pnl - - if pnl > 0: - self.profitable_trades += 1 - - # Track daily and hourly PnL - now = datetime.now() - day_key = now.strftime('%Y-%m-%d') - hour_key = now.strftime('%Y-%m-%d %H:00') - - # Update daily PnL - if day_key not in self.daily_pnl: - self.daily_pnl[day_key] = 0.0 - self.daily_pnl[day_key] += pnl - - # Update hourly PnL - if hour_key not in self.hourly_pnl: - self.hourly_pnl[hour_key] = 0.0 - self.hourly_pnl[hour_key] += pnl - + if len(self.trade_pnls) > 100: + self.trade_pnls.pop(0) + + def record_trade(self, action): + """Record trade action for frequency penalty calculations""" + from time import time + self.trade_timestamps.append(time()) + if len(self.trade_timestamps) > 100: + self.trade_timestamps.pop(0) + def _calculate_frequency_penalty(self): - """Calculate penalty for trading too frequently""" - if len(self.recent_trades) < 2: + """Calculate penalty for high-frequency trading""" + if len(self.trade_timestamps) < 2: return 0.0 - - # Count trades in the last minute - now = datetime.now() - one_minute_ago = now - timedelta(minutes=1) - trades_last_minute = sum(1 for trade in self.recent_trades - if trade['timestamp'] > one_minute_ago) - - # Apply progressive penalty (more severe as frequency increases) - if trades_last_minute <= 1: - return 0.0 # No penalty for normal trading rate - - # Progressive penalty based on trade frequency - penalty = min(self.max_frequency_penalty, - self.base_fee_rate * trades_last_minute) - - return penalty - - def _calculate_holding_reward(self, position_held_time, price_change_pct): - """Calculate reward for holding a position for some time""" - if position_held_time <= 0 or price_change_pct <= 0: - return 0.0 # No reward for unprofitable holds - - # Cap at 100 time units (seconds, minutes, etc.) - capped_time = min(position_held_time, 100) - - # Scale reward by both time and price change - reward = self.holding_reward_rate * capped_time * price_change_pct - - return reward - + time_span = self.trade_timestamps[-1] - self.trade_timestamps[0] + if time_span <= 0: + return 0.0 + trades_per_minute = (len(self.trade_timestamps) / time_span) * 60 + if trades_per_minute > self.frequency_threshold: + penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001) + return penalty + return 0.0 + def _calculate_risk_adjustment(self, reward): """Adjust rewards based on risk (simple Sharpe ratio implementation)""" if len(self.trade_pnls) < 5: - return reward # Not enough data for adjustment - - # Calculate mean and standard deviation of returns + return reward pnl_array = np.array(self.trade_pnls) mean_return = np.mean(pnl_array) std_return = np.std(pnl_array) - if std_return == 0: - return reward # Avoid division by zero - - # Simplified Sharpe ratio + return reward sharpe = mean_return / std_return - - # Scale reward by Sharpe ratio (normalized to be around 1.0) adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0) - return reward * adjustment_factor - - def calculate_reward(self, action, price_change, position_held_time=0, - volatility=None, is_profitable=False): - """ - Calculate the improved reward - - Args: - action (int): 0 = Buy, 1 = Sell, 2 = Hold - price_change (float): Percent price change for the trade - position_held_time (int): Time position was held (in time units) - volatility (float, optional): Market volatility measure - is_profitable (bool): Whether current position is profitable - - Returns: - float: Calculated reward value - """ - # Calculate trading fee + + def _calculate_holding_reward(self, position_held_time, price_change): + """Calculate reward for holding a position""" + base_holding_reward = 0.0005 * (position_held_time / 60.0) + if price_change > 0: + return base_holding_reward * 2 + elif price_change < 0: + return base_holding_reward * 0.5 + return base_holding_reward + + def calculate_basic_reward(self, pnl, confidence): + """Calculate basic training reward based on P&L and confidence""" + try: + base_reward = pnl + if pnl < 0 and confidence > 0.7: + confidence_adjustment = -confidence * 2 + elif pnl > 0 and confidence > 0.7: + confidence_adjustment = confidence * 1.5 + else: + confidence_adjustment = 0 + final_reward = base_reward + confidence_adjustment + normalized_reward = np.tanh(final_reward / 10.0) + logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}") + return float(normalized_reward) + except Exception as e: + logger.error(f"Error calculating basic reward: {e}") + return 0.0 + + def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'): + """Calculate enhanced reward for trading actions""" fee = self.base_fee_rate - - # Calculate frequency penalty frequency_penalty = self._calculate_frequency_penalty() - - # Base reward calculation if action == 0: # Buy - # Small penalty for transaction plus frequency penalty reward = -fee - frequency_penalty - elif action == 1: # Sell - # Calculate profit percentage minus fees (both entry and exit) profit_pct = price_change net_profit = profit_pct - (fee * 2) - - # Scale reward and apply frequency penalty - reward = net_profit * 10 # Scale reward + reward = net_profit * self.reward_scaling reward -= frequency_penalty - - # Record PnL for risk adjustment self.record_pnl(net_profit) - else: # Hold - # Small reward for holding a profitable position, small cost otherwise if is_profitable: reward = self._calculate_holding_reward(position_held_time, price_change) else: - reward = -0.0001 # Very small negative reward - - # Apply risk adjustment if enabled - if self.risk_adjusted: - reward = self._calculate_risk_adjustment(reward) - - # Record this action for future frequency calculations - self.record_trade(action=action) - + reward = -0.0001 + if action in [0, 1] and predicted_change != 0: + if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0): + reward += abs(actual_change) * 5.0 + else: + reward -= abs(predicted_change) * 2.0 + reward += current_pnl * 0.1 + if volatility is not None: + reward -= abs(volatility) * 100 + if self.risk_aversion > 0 and len(self.returns) > 1: + returns_std = np.std(self.returns) + reward -= returns_std * self.risk_aversion + self.record_trade(action) + return reward + + def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0): + """Calculate reward for prediction accuracy""" + reward = 0.0 + if predicted_direction == actual_direction: + reward += 1.0 * confidence + else: + reward -= 0.5 + if predicted_direction == actual_direction and abs(predicted_change) > 0.001: + reward += abs(actual_change) * 5.0 + if predicted_direction != actual_direction and abs(predicted_change) > 0.001: + reward -= abs(predicted_change) * 2.0 + reward += current_pnl * 0.1 + # Dynamic adjustment based on recent PnL (loss cutting incentive) + if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]: + latest_pnl_entry = self.pnl_history[symbol][-1] + latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0 + if latest_pnl_value < 0 and position_duration > 60: + reward -= (abs(latest_pnl_value) * 0.2) + pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)] + best_pnl = max(pnl_values) if pnl_values else 0.0 + if best_pnl < 0.0: + reward -= 0.1 return reward # Example usage: if __name__ == "__main__": # Create calculator instance - reward_calc = ImprovedRewardCalculator() + reward_calc = RewardCalculator() # Example reward for a buy action - buy_reward = reward_calc.calculate_reward(action=0, price_change=0) + buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0) print(f"Buy action reward: {buy_reward:.5f}") # Record a trade for frequency tracking - reward_calc.record_trade(action=0) + reward_calc.record_trade(0) # Wait a bit and make another trade to test frequency penalty import time time.sleep(0.1) # Example reward for a sell action with profit - sell_reward = reward_calc.calculate_reward(action=1, price_change=0.015, position_held_time=60) + sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60) print(f"Sell action reward (with profit): {sell_reward:.5f}") # Example reward for a hold action on profitable position - hold_reward = reward_calc.calculate_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True) + hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True) print(f"Hold action reward (profitable): {hold_reward:.5f}") # Example reward for a hold action on unprofitable position - hold_reward_neg = reward_calc.calculate_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False) + hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False) print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}") \ No newline at end of file diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index c259c98..ba6b6cc 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -5426,9 +5426,24 @@ class CleanTradingDashboard: confidence_target_tensor = torch.FloatTensor([confidence_target]).to(device) network.train() - action_logits, predicted_confidence = network(features_tensor) + network_output = network(features_tensor) + + # Handle different return formats from network + if isinstance(network_output, tuple) and len(network_output) == 2: + action_logits, predicted_confidence = network_output + elif hasattr(network_output, 'dim'): + # Single tensor output - assume it's action logits + action_logits = network_output + predicted_confidence = torch.tensor(0.5, device=device) # Default confidence + else: + logger.debug(f"Unexpected network output format: {type(network_output)}") + continue + + # Ensure predicted_confidence is a tensor with proper dimensions + if not hasattr(predicted_confidence, 'dim'): + # If it's not a tensor, convert it + predicted_confidence = torch.tensor(float(predicted_confidence), device=device) - # Ensure predicted_confidence has a batch dimension if it doesn't already if predicted_confidence.dim() == 0: predicted_confidence = predicted_confidence.unsqueeze(0) @@ -6048,4 +6063,7 @@ def create_clean_dashboard(data_provider: Optional[DataProvider] = None, orchest data_provider=data_provider, orchestrator=orchestrator, trading_executor=trading_executor - ) \ No newline at end of file + ) + + + # test edit \ No newline at end of file diff --git a/web/component_manager.py b/web/component_manager.py index ebb0e0a..069fab3 100644 --- a/web/component_manager.py +++ b/web/component_manager.py @@ -275,18 +275,28 @@ class DashboardComponentManager: def format_cob_data(self, cob_snapshot, symbol, cumulative_imbalance_stats=None): """Format COB data into a split view with summary, imbalance stats, and a compact ladder.""" try: - if not cob_snapshot or not hasattr(cob_snapshot, 'stats'): + if not cob_snapshot: return html.Div([ html.H6(f"{symbol} COB", className="mb-2"), html.P("No COB data available", className="text-muted small") ]) - stats = cob_snapshot.stats if hasattr(cob_snapshot, 'stats') else {} - mid_price = stats.get('mid_price', 0) - spread_bps = stats.get('spread_bps', 0) - imbalance = stats.get('imbalance', 0) - bids = getattr(cob_snapshot, 'consolidated_bids', []) - asks = getattr(cob_snapshot, 'consolidated_asks', []) + # Handle both old format (with stats attribute) and new format (direct attributes) + if hasattr(cob_snapshot, 'stats'): + # Old format with stats attribute + stats = cob_snapshot.stats + mid_price = stats.get('mid_price', 0) + spread_bps = stats.get('spread_bps', 0) + imbalance = stats.get('imbalance', 0) + bids = getattr(cob_snapshot, 'consolidated_bids', []) + asks = getattr(cob_snapshot, 'consolidated_asks', []) + else: + # New COBSnapshot format with direct attributes + mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0) + spread_bps = getattr(cob_snapshot, 'spread_bps', 0) + imbalance = getattr(cob_snapshot, 'liquidity_imbalance', 0) + bids = getattr(cob_snapshot, 'consolidated_bids', []) + asks = getattr(cob_snapshot, 'consolidated_asks', []) if mid_price == 0 or not bids or not asks: return html.Div([ @@ -294,6 +304,17 @@ class DashboardComponentManager: html.P("Awaiting valid order book data...", className="text-muted small") ]) + # Create stats dict for compatibility with existing code + stats = { + 'mid_price': mid_price, + 'spread_bps': spread_bps, + 'imbalance': imbalance, + 'total_bid_liquidity': getattr(cob_snapshot, 'total_bid_liquidity', 0), + 'total_ask_liquidity': getattr(cob_snapshot, 'total_ask_liquidity', 0), + 'bid_levels': len(bids), + 'ask_levels': len(asks) + } + # --- Left Panel: Overview and Stats --- overview_panel = self._create_cob_overview_panel(symbol, stats, cumulative_imbalance_stats) @@ -381,10 +402,19 @@ class DashboardComponentManager: def aggregate_buckets(orders): buckets = {} for order in orders: - price = order.get('price', 0) - # Handle both old format (size) and new format (total_size) - size = order.get('total_size', order.get('size', 0)) - volume_usd = order.get('total_volume_usd', size * price) + # Handle both dictionary format and ConsolidatedOrderBookLevel objects + if hasattr(order, 'price'): + # ConsolidatedOrderBookLevel object + price = order.price + size = order.total_size + volume_usd = order.total_volume_usd + else: + # Dictionary format (legacy) + price = order.get('price', 0) + # Handle both old format (size) and new format (total_size) + size = order.get('total_size', order.get('size', 0)) + volume_usd = order.get('total_volume_usd', size * price) + if price > 0: bucket_key = round(price / bucket_size) * bucket_size if bucket_key not in buckets: