#!/usr/bin/env python3
"""
Test Enhanced Inference Logging

This script tests the enhanced inference logging system that stores
full input features for training feedback.
"""

import sys
import os
import logging
import numpy as np
from datetime import datetime, timedelta

# Add the project root to the import path so the core/ and utils/ packages resolve
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
from utils.database_manager import get_database_manager
from utils.inference_logger import get_inference_logger

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_test_base_data():
    """Create a test BaseDataInput populated with realistic synthetic data"""

    # Timeframe durations used to space the synthetic bar timestamps
    timeframe_seconds = {'1s': 1, '1m': 60, '1h': 3600, '1d': 86400}

    def create_ohlcv_bars(symbol, timeframe, count=300):
        """Generate `count` synthetic OHLCV bars ending at the current time"""
        bars = []
        base_price = 3000.0 if 'ETH' in symbol else 50000.0
        step = timedelta(seconds=timeframe_seconds.get(timeframe, 60))
        now = datetime.now()

        for i in range(count):
            # Random walk around the base price; keep high/low consistent with
            # the generated open/close so every bar is well-formed
            open_price = base_price + np.random.normal(0, base_price * 0.01)
            close_price = open_price + np.random.normal(0, open_price * 0.005)
            bars.append(OHLCVBar(
                symbol=symbol,
                timestamp=now - (count - 1 - i) * step,
                open=open_price,
                high=max(open_price, close_price) * 1.002,
                low=min(open_price, close_price) * 0.998,
                close=close_price,
                volume=np.random.uniform(100, 1000),
                timeframe=timeframe
            ))
        return bars

    base_data = BaseDataInput(
        symbol="ETH/USDT",
        timestamp=datetime.now(),
        ohlcv_1s=create_ohlcv_bars("ETH/USDT", "1s", 300),
        ohlcv_1m=create_ohlcv_bars("ETH/USDT", "1m", 300),
        ohlcv_1h=create_ohlcv_bars("ETH/USDT", "1h", 300),
        ohlcv_1d=create_ohlcv_bars("ETH/USDT", "1d", 300),
        btc_ohlcv_1s=create_ohlcv_bars("BTC/USDT", "1s", 300),
        technical_indicators={
            'rsi': 45.5,
            'macd': 0.12,
            'bb_upper': 3100.0,
            'bb_lower': 2900.0,
            'volume_ma': 500.0
        }
    )

    return base_data

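# The test below exercises the feedback loop end to end:
# predict -> log the inference (including the full input feature vector) ->
# reload the logged features as training samples -> evaluate past predictions ->
# train on the recovered samples.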
def test_enhanced_inference_logging():
    """Test the enhanced inference logging system"""

    logger.info("=== Testing Enhanced Inference Logging ===")

    try:
        # Initialize CNN adapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialized")

        # Create test data
        base_data = create_test_base_data()
        logger.info("✅ Test data created")

        # Make a prediction (this should log inference data)
        logger.info("Making prediction...")
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"✅ Prediction made: {model_output.predictions['action']} (confidence: {model_output.confidence:.3f})")

        # Verify the inference was logged to the database
        db_manager = get_database_manager()
        recent_inferences = db_manager.get_recent_inferences(cnn_adapter.model_name, limit=1)

        if recent_inferences:
            latest_inference = recent_inferences[0]
            logger.info("✅ Inference logged to database:")
            logger.info(f"  Model: {latest_inference.model_name}")
            logger.info(f"  Action: {latest_inference.action}")
            logger.info(f"  Confidence: {latest_inference.confidence:.3f}")
            logger.info(f"  Processing time: {latest_inference.processing_time_ms:.1f}ms")
            logger.info(f"  Has input features: {latest_inference.input_features is not None}")

            if latest_inference.input_features is not None:
                logger.info(f"  Input features shape: {latest_inference.input_features.shape}")
                logger.info(f"  Input features sample: {latest_inference.input_features[:5]}")
        else:
            logger.error("❌ No inference records found in database")
            return False

        # Test training data loading from inference history
        logger.info("Testing training data loading from inference history...")
        original_training_count = len(cnn_adapter.training_data)
        cnn_adapter._load_training_data_from_inference_history()
        new_training_count = len(cnn_adapter.training_data)

        logger.info(f"✅ Training data loaded: {original_training_count} -> {new_training_count} samples")

        # Test prediction evaluation
        logger.info("Testing prediction evaluation...")
        evaluation_metrics = cnn_adapter.evaluate_predictions_against_outcomes(hours_back=1)
        logger.info(f"✅ Evaluation metrics: {evaluation_metrics}")

        # Test training with inference data
        if new_training_count >= cnn_adapter.batch_size:
            logger.info("Testing training with inference data...")
            training_metrics = cnn_adapter.train(epochs=1)
            logger.info(f"✅ Training completed: {training_metrics}")
        else:
            logger.info("⚠️ Not enough training data for training test")

        return True

    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def test_database_query_methods():
    """Test the new database query methods"""

    logger.info("=== Testing Database Query Methods ===")

    try:
        db_manager = get_database_manager()

        # Test getting inference records for training
        training_records = db_manager.get_inference_records_for_training(
            model_name="enhanced_cnn",
            hours_back=24,
            limit=10
        )

        logger.info(f"✅ Found {len(training_records)} training records")

        for i, record in enumerate(training_records[:3]):  # Show the first 3 records
            logger.info(f"  Record {i+1}:")
            logger.info(f"    Action: {record.action}")
            logger.info(f"    Confidence: {record.confidence:.3f}")
            logger.info(f"    Has features: {record.input_features is not None}")
            if record.input_features is not None:
                logger.info(f"    Features shape: {record.input_features.shape}")

        return True

    except Exception as e:
        logger.error(f"❌ Database query test failed: {e}")
        return False

def main():
    """Run all tests"""

    logger.info("Starting Enhanced Inference Logging Tests")

    # Test 1: Enhanced inference logging
    test1_passed = test_enhanced_inference_logging()

    # Test 2: Database query methods
    test2_passed = test_database_query_methods()

    # Summary
    logger.info("=== Test Summary ===")
    logger.info(f"Enhanced Inference Logging: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
    logger.info(f"Database Query Methods: {'✅ PASSED' if test2_passed else '❌ FAILED'}")

    if test1_passed and test2_passed:
        logger.info("🎉 All tests passed! Enhanced inference logging is working correctly.")
        logger.info("The system now:")
        logger.info("  - Stores full input features with each inference")
        logger.info("  - Can retrieve inference data for training feedback")
        logger.info("  - Supports continuous learning from inference history")
        logger.info("  - Evaluates prediction accuracy over time")
    else:
        logger.error("❌ Some tests failed. Please check the implementation.")

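# Usage note (assumptions: this script sits at the project root so the `core`
# and `utils` packages resolve, and the database behind get_database_manager()
# is reachable; otherwise both tests are expected to fail with connection errors):
#
#   python path/to/this_script.py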
if __name__ == "__main__":
    main()