Files
gogo2/test_enhanced_inference_logging.py
2025-07-26 23:34:36 +03:00

193 lines
7.3 KiB
Python

#!/usr/bin/env python3
"""
Test Enhanced Inference Logging
This script tests the enhanced inference logging system that stores
full input features for training feedback.
"""
import sys
import os
import logging
import numpy as np
from datetime import datetime
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
from utils.database_manager import get_database_manager
from utils.inference_logger import get_inference_logger
# Configure the root logger at INFO level (no-op if the root logger already
# has handlers), then obtain this script's own logger by module name.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
def create_test_base_data(seed=None):
    """Create a BaseDataInput populated with synthetic, internally consistent data.

    Builds 300 OHLCV bars for each ETH/USDT timeframe (1s, 1m, 1h, 1d) plus
    300 BTC/USDT 1s bars, and attaches a fixed set of technical indicators.

    Args:
        seed: Optional seed for a dedicated NumPy random Generator. When None
            (the default) the global ``np.random`` state is used, so callers
            that seed ``np.random`` themselves get the same draw sequence as
            before.

    Returns:
        A BaseDataInput for symbol "ETH/USDT".
    """
    from datetime import timedelta

    # A seeded Generator makes the data reproducible; otherwise fall back to
    # the module-level np.random functions (same normal/uniform signatures).
    rng = np.random if seed is None else np.random.default_rng(seed)
    now = datetime.now()
    # Seconds per unit for timeframe strings such as "1s", "1m", "1h", "1d".
    unit_seconds = {"s": 1, "m": 60, "h": 3600, "d": 86400}

    def create_ohlcv_bars(symbol, timeframe, count=300):
        """Return `count` bars for `symbol`, oldest first, the last one at `now`."""
        step = timedelta(seconds=int(timeframe[:-1]) * unit_seconds[timeframe[-1]])
        base_price = 3000.0 if 'ETH' in symbol else 50000.0
        bars = []
        for i in range(count):
            open_price = base_price + rng.normal(0, base_price * 0.01)
            close_price = open_price + rng.normal(0, open_price * 0.005)
            bars.append(OHLCVBar(
                symbol=symbol,
                # Space bars one timeframe apart instead of stamping them all
                # with (nearly) the same wall-clock time.
                timestamp=now - step * (count - 1 - i),
                open=float(open_price),
                # Derive high/low from both open and close so every bar
                # satisfies low <= min(open, close) <= max(open, close) <= high.
                high=float(max(open_price, close_price) * 1.002),
                low=float(min(open_price, close_price) * 0.998),
                close=float(close_price),
                volume=float(rng.uniform(100, 1000)),
                timeframe=timeframe,
            ))
        return bars

    base_data = BaseDataInput(
        symbol="ETH/USDT",
        timestamp=now,
        ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300),
        ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300),
        ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300),
        ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300),
        btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300),
        technical_indicators={
            'rsi': 45.5,
            'macd': 0.12,
            'bb_upper': 3100.0,
            'bb_lower': 2900.0,
            'volume_ma': 500.0,
        },
    )
    return base_data
def test_enhanced_inference_logging():
    """Exercise the end-to-end inference logging and training-feedback path.

    Steps:
      1. Run one CNN prediction on synthetic data.
      2. Confirm the newest inference record for the adapter's model exists
         and carries input features.
      3. Reload training data from inference history.
      4. Evaluate recent predictions.
      5. Run one training epoch if enough samples are available.

    Returns:
        True if every step succeeds, False otherwise. Failures are logged.
    """
    logger.info("=== Testing Enhanced Inference Logging ===")
    try:
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialized")

        base_data = create_test_base_data()
        logger.info("✅ Test data created")

        # Make a prediction (this should log inference data)
        logger.info("Making prediction...")
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})")

        # Verify inference was logged to the database.
        # NOTE(review): this reads the newest record for the model, which could
        # predate this run if predict() did not log; consider comparing a
        # record timestamp against the test start time.
        db_manager = get_database_manager()
        recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1)
        if not recent_inferences:
            logger.error("❌ No inference records found in database")
            return False

        latest_inference = recent_inferences[0]
        features = latest_inference.input_features
        logger.info("✅ Inference logged to database:")
        logger.info(f" Model: {latest_inference.model_name}")
        logger.info(f" Action: {latest_inference.action}")
        logger.info(f" Confidence: {latest_inference.confidence:.3f}")
        logger.info(f" Processing time: {latest_inference.processing_time_ms:.1f}ms")
        logger.info(f" Has input features: {features is not None}")
        if features is None:
            # Storing full input features is the behavior under test, so a
            # record without them is a failure, not just an informational line.
            logger.error("❌ Latest inference record has no input features")
            return False
        # np.shape works for both arrays and plain sequences.
        logger.info(f" Input features shape: {np.shape(features)}")
        logger.info(f" Input features sample: {features[:5]}")

        # Test training data loading from inference history
        logger.info("Testing training data loading from inference history...")
        original_training_count = len(cnn_adapter.training_data)
        cnn_adapter._load_training_data_from_inference_history()
        new_training_count = len(cnn_adapter.training_data)
        logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples")

        # Test prediction evaluation
        logger.info("Testing prediction evaluation...")
        evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1)
        logger.info(f"✅ Evaluation metrics: {evaluation_metrics}")

        # Only train when at least one full batch is available.
        if new_training_count >= cnn_adapter.batch_size:
            logger.info("Testing training with inference data...")
            training_metrics = cnn_adapter.train(epochs=1)
            logger.info(f"✅ Training completed: {training_metrics}")
        else:
            logger.info("⚠️ Not enough training data for training test")
        return True
    except Exception as e:
        # logger.exception records the full traceback through logging.
        logger.exception(f"❌ Test failed: {e}")
        return False
def test_database_query_methods(model_name="enhanced_cnn", hours_back=24, limit=10):
    """Query inference records intended for training and log a short summary.

    Args:
        model_name: Model whose inference records are fetched.
        hours_back: Look-back window passed to the query.
        limit: Maximum number of records requested.

    Returns:
        True if the query succeeds, False if it raises. The error is logged
        with its traceback.
    """
    logger.info("=== Testing Database Query Methods ===")
    try:
        db_manager = get_database_manager()
        training_records = db_manager.get_inference_records_for_training(
            model_name=model_name,
            hours_back=hours_back,
            limit=limit,
        )
        logger.info(f"✅ Found {len(training_records)} training records")

        # Show only the first three records to keep the log short.
        for i, record in enumerate(training_records[:3], start=1):
            logger.info(f" Record {i}:")
            logger.info(f" Action: {record.action}")
            logger.info(f" Confidence: {record.confidence:.3f}")
            logger.info(f" Has features: {record.input_features is not None}")
            if record.input_features is not None:
                # np.shape works for both arrays and plain sequences.
                logger.info(f" Features shape: {np.shape(record.input_features)}")
        return True
    except Exception as e:
        # Include the traceback, consistent with the other test in this file.
        logger.exception(f"❌ Database query test failed: {e}")
        return False
def main():
    """Run both tests in order, log a summary, and return a process exit code.

    Returns:
        0 if both tests passed, 1 otherwise. The ``__main__`` guard passes
        this to ``sys.exit`` so CI can detect failures.
    """
    logger.info("Starting Enhanced Inference Logging Tests")

    # Dict literal entries are evaluated in order, so test 1 runs before
    # test 2, and the summary is printed in the same order.
    results = {
        "Enhanced Inference Logging": test_enhanced_inference_logging(),
        "Database Query Methods": test_database_query_methods(),
    }

    logger.info("=== Test Summary ===")
    for test_name, passed in results.items():
        logger.info(f"{test_name}: {'✅ PASSED' if passed else '❌ FAILED'}")

    if all(results.values()):
        logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.")
        logger.info("The system now:")
        logger.info(" - Stores full input features with each inference")
        logger.info(" - Can retrieve inference data for training feedback")
        logger.info(" - Supports continuous learning from inference history")
        logger.info(" - Evaluates prediction accuracy over time")
        return 0

    logger.error("❌ Some tests failed. Please check the implementation.")
    return 1


if __name__ == "__main__":
    # Propagate the result so a failing run exits non-zero.
    sys.exit(main())