#!/usr/bin/env python3
"""
Test Enhanced Inference Logging

This script tests the enhanced inference logging system that stores full
input features for training feedback.

Exit status: 0 if all tests pass, 1 otherwise (so CI can detect failures).
"""

import sys
import os
import logging
import numpy as np
from datetime import datetime

# Add project root to path so the project-local imports below resolve
# when this script is run directly from the repository checkout.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
from utils.database_manager import get_database_manager
from utils.inference_logger import get_inference_logger  # noqa: F401 (kept: may be required for import side effects — TODO confirm)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_test_base_data():
    """Create a test BaseDataInput populated with synthetic OHLCV data.

    Returns:
        BaseDataInput for ETH/USDT with 300 randomly jittered bars per
        timeframe (1s/1m/1h/1d plus BTC 1s) and a fixed set of
        technical-indicator values.
    """

    def create_ohlcv_bars(symbol, timeframe, count=300):
        """Build `count` synthetic OHLCVBar objects around a per-symbol base price."""
        bars = []
        # Rough price scale: ETH-like symbols near 3000, everything else near 50000.
        base_price = 3000.0 if 'ETH' in symbol else 50000.0
        for i in range(count):
            # Jitter the open ~1% around base; close jitters ~0.5% around open.
            price = base_price + np.random.normal(0, base_price * 0.01)
            bars.append(OHLCVBar(
                symbol=symbol,
                timestamp=datetime.now(),
                open=price,
                high=price * 1.002,
                low=price * 0.998,
                close=price + np.random.normal(0, price * 0.005),
                volume=np.random.uniform(100, 1000),
                timeframe=timeframe
            ))
        return bars

    base_data = BaseDataInput(
        symbol="ETH/USDT",
        timestamp=datetime.now(),
        ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300),
        ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300),
        ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300),
        ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300),
        btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300),
        technical_indicators={
            'rsi': 45.5,
            'macd': 0.12,
            'bb_upper': 3100.0,
            'bb_lower': 2900.0,
            'volume_ma': 500.0
        }
    )

    return base_data


def test_enhanced_inference_logging():
    """Test the enhanced inference logging system.

    Exercises the full round trip: predict -> verify the inference record
    (including stored input features) in the database -> reload training
    data from inference history -> evaluate predictions -> optionally train.

    Returns:
        True on success, False on any failure (exceptions are logged).
    """
    logger.info("=== Testing Enhanced Inference Logging ===")

    try:
        # Initialize CNN adapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialized")

        # Create test data
        base_data = create_test_base_data()
        logger.info("✅ Test data created")

        # Make a prediction (this should log inference data)
        logger.info("Making prediction...")
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})")

        # Verify inference was logged to database
        db_manager = get_database_manager()
        recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1)

        if recent_inferences:
            latest_inference = recent_inferences[0]
            logger.info("✅ Inference logged to database:")
            logger.info(f"  Model: {latest_inference.model_name}")
            logger.info(f"  Action: {latest_inference.action}")
            logger.info(f"  Confidence: {latest_inference.confidence:.3f}")
            logger.info(f"  Processing time: {latest_inference.processing_time_ms:.1f}ms")
            logger.info(f"  Has input features: {latest_inference.input_features is not None}")

            if latest_inference.input_features is not None:
                logger.info(f"  Input features shape: {latest_inference.input_features.shape}")
                logger.info(f"  Input features sample: {latest_inference.input_features[:5]}")
        else:
            logger.error("❌ No inference records found in database")
            return False

        # Test training data loading from inference history
        logger.info("Testing training data loading from inference history...")
        original_training_count = len(cnn_adapter.training_data)
        cnn_adapter._load_training_data_from_inference_history()
        new_training_count = len(cnn_adapter.training_data)
        logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples")

        # Test prediction evaluation
        logger.info("Testing prediction evaluation...")
        evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1)
        logger.info(f"✅ Evaluation metrics: {evaluation_metrics}")

        # Test training with inference data (only when a full batch is available)
        if new_training_count >= cnn_adapter.batch_size:
            logger.info("Testing training with inference data...")
            training_metrics = cnn_adapter.train(epochs=1)
            logger.info(f"✅ Training completed: {training_metrics}")
        else:
            logger.info("⚠️ Not enough training data for training test")

        return True

    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_database_query_methods():
    """Test the new database query methods.

    Returns:
        True on success, False on any failure (exceptions are logged).
    """
    logger.info("=== Testing Database Query Methods ===")

    try:
        db_manager = get_database_manager()

        # Test getting inference records for training
        training_records = db_manager.get_inference_records_for_training(
            model_name="enhanced_cnn",
            hours_back=24,
            limit=10
        )

        logger.info(f"✅ Found {len(training_records)} training records")

        for i, record in enumerate(training_records[:3]):  # Show first 3
            logger.info(f"  Record {i+1}:")
            logger.info(f"    Action: {record.action}")
            logger.info(f"    Confidence: {record.confidence:.3f}")
            logger.info(f"    Has features: {record.input_features is not None}")
            if record.input_features is not None:
                logger.info(f"    Features shape: {record.input_features.shape}")

        return True

    except Exception as e:
        logger.error(f"❌ Database query test failed: {e}")
        return False


def main():
    """Run all tests and report a summary.

    Returns:
        Process exit status: 0 when every test passed, 1 otherwise.
    """
    logger.info("Starting Enhanced Inference Logging Tests")

    # Test 1: Enhanced inference logging
    test1_passed = test_enhanced_inference_logging()

    # Test 2: Database query methods
    test2_passed = test_database_query_methods()

    # Summary
    logger.info("=== Test Summary ===")
    logger.info(f"Enhanced Inference Logging: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
    logger.info(f"Database Query Methods: {'✅ PASSED' if test2_passed else '❌ FAILED'}")

    if test1_passed and test2_passed:
        logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.")
        logger.info("The system now:")
        logger.info("  - Stores full input features with each inference")
        logger.info("  - Can retrieve inference data for training feedback")
        logger.info("  - Supports continuous learning from inference history")
        logger.info("  - Evaluates prediction accuracy over time")
        return 0

    logger.error("❌ Some tests failed. Please check the implementation.")
    return 1


if __name__ == "__main__":
    # Propagate the result as the process exit code so CI/shell callers
    # can detect failures (previously failures exited 0).
    sys.exit(main())