inrefence predictions fix
This commit is contained in:
193
test_enhanced_inference_logging.py
Normal file
193
test_enhanced_inference_logging.py
Normal file
@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Inference Logging
|
||||
|
||||
This script tests the enhanced inference logging system that stores
|
||||
full input features for training feedback.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from core.data_models import BaseDataInput, OHLCVBar
|
||||
from utils.database_manager import get_database_manager
|
||||
from utils.inference_logger import get_inference_logger
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_test_base_data():
|
||||
"""Create test BaseDataInput with realistic data"""
|
||||
|
||||
# Create OHLCV bars for different timeframes
|
||||
def create_ohlcv_bars(symbol, timeframe, count=300):
|
||||
bars = []
|
||||
base_price = 3000.0 if 'ETH' in symbol else 50000.0
|
||||
|
||||
for i in range(count):
|
||||
price = base_price + np.random.normal(0, base_price * 0.01)
|
||||
bars.append(OHLCVBar(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
open=price,
|
||||
high=price * 1.002,
|
||||
low=price * 0.998,
|
||||
close=price + np.random.normal(0, price * 0.005),
|
||||
volume=np.random.uniform(100, 1000),
|
||||
timeframe=timeframe
|
||||
))
|
||||
return bars
|
||||
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300),
|
||||
ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300),
|
||||
ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300),
|
||||
ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300),
|
||||
btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300),
|
||||
technical_indicators={
|
||||
'rsi': 45.5,
|
||||
'macd': 0.12,
|
||||
'bb_upper': 3100.0,
|
||||
'bb_lower': 2900.0,
|
||||
'volume_ma': 500.0
|
||||
}
|
||||
)
|
||||
|
||||
return base_data
|
||||
|
||||
def test_enhanced_inference_logging():
|
||||
"""Test the enhanced inference logging system"""
|
||||
|
||||
logger.info("=== Testing Enhanced Inference Logging ===")
|
||||
|
||||
try:
|
||||
# Initialize CNN adapter
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
logger.info("✅ CNN adapter initialized")
|
||||
|
||||
# Create test data
|
||||
base_data = create_test_base_data()
|
||||
logger.info("✅ Test data created")
|
||||
|
||||
# Make a prediction (this should log inference data)
|
||||
logger.info("Making prediction...")
|
||||
model_output = cnn_adapter.predict(base_data)
|
||||
logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})")
|
||||
|
||||
# Verify inference was logged to database
|
||||
db_manager = get_database_manager()
|
||||
recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1)
|
||||
|
||||
if recent_inferences:
|
||||
latest_inference = recent_inferences[0]
|
||||
logger.info(f"✅ Inference logged to database:")
|
||||
logger.info(f" Model: {latest_inference.model_name}")
|
||||
logger.info(f" Action: {latest_inference.action}")
|
||||
logger.info(f" Confidence: {latest_inference.confidence:.3f}")
|
||||
logger.info(f" Processing time: {latest_inference.processing_time_ms:.1f}ms")
|
||||
logger.info(f" Has input features: {latest_inference.input_features is not None}")
|
||||
|
||||
if latest_inference.input_features is not None:
|
||||
logger.info(f" Input features shape: {latest_inference.input_features.shape}")
|
||||
logger.info(f" Input features sample: {latest_inference.input_features[:5]}")
|
||||
else:
|
||||
logger.error("❌ No inference records found in database")
|
||||
return False
|
||||
|
||||
# Test training data loading from inference history
|
||||
logger.info("Testing training data loading from inference history...")
|
||||
original_training_count = len(cnn_adapter.training_data)
|
||||
cnn_adapter._load_training_data_from_inference_history()
|
||||
new_training_count = len(cnn_adapter.training_data)
|
||||
|
||||
logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples")
|
||||
|
||||
# Test prediction evaluation
|
||||
logger.info("Testing prediction evaluation...")
|
||||
evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1)
|
||||
logger.info(f"✅ Evaluation metrics: {evaluation_metrics}")
|
||||
|
||||
# Test training with inference data
|
||||
if new_training_count >= cnn_adapter.batch_size:
|
||||
logger.info("Testing training with inference data...")
|
||||
training_metrics = cnn_adapter.train(epochs=1)
|
||||
logger.info(f"✅ Training completed: {training_metrics}")
|
||||
else:
|
||||
logger.info("⚠️ Not enough training data for training test")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_database_query_methods():
|
||||
"""Test the new database query methods"""
|
||||
|
||||
logger.info("=== Testing Database Query Methods ===")
|
||||
|
||||
try:
|
||||
db_manager = get_database_manager()
|
||||
|
||||
# Test getting inference records for training
|
||||
training_records = db_manager.get_inference_records_for_training(
|
||||
model_name="enhanced_cnn",
|
||||
hours_back=24,
|
||||
limit=10
|
||||
)
|
||||
|
||||
logger.info(f"✅ Found {len(training_records)} training records")
|
||||
|
||||
for i, record in enumerate(training_records[:3]): # Show first 3
|
||||
logger.info(f" Record {i+1}:")
|
||||
logger.info(f" Action: {record.action}")
|
||||
logger.info(f" Confidence: {record.confidence:.3f}")
|
||||
logger.info(f" Has features: {record.input_features is not None}")
|
||||
if record.input_features is not None:
|
||||
logger.info(f" Features shape: {record.input_features.shape}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Database query test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
|
||||
logger.info("Starting Enhanced Inference Logging Tests")
|
||||
|
||||
# Test 1: Enhanced inference logging
|
||||
test1_passed = test_enhanced_inference_logging()
|
||||
|
||||
# Test 2: Database query methods
|
||||
test2_passed = test_database_query_methods()
|
||||
|
||||
# Summary
|
||||
logger.info("=== Test Summary ===")
|
||||
logger.info(f"Enhanced Inference Logging: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
|
||||
logger.info(f"Database Query Methods: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
|
||||
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.")
|
||||
logger.info("The system now:")
|
||||
logger.info(" - Stores full input features with each inference")
|
||||
logger.info(" - Can retrieve inference data for training feedback")
|
||||
logger.info(" - Supports continuous learning from inference history")
|
||||
logger.info(" - Evaluates prediction accuracy over time")
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the implementation.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user