175 lines
6.2 KiB
Python
175 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test CNN Integration
|
|
|
|
This script tests if the CNN adapter is working properly and identifies issues.
|
|
"""
|
|
|
|
import logging
|
|
import sys
|
|
import os
|
|
from datetime import datetime
|
|
|
|
# Setup logging: INFO and above to the root handler, each record stamped with
# time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def test_cnn_adapter():
    """Check that EnhancedCNNAdapter imports, initializes, and can train briefly.

    Steps:
        1. Import ``EnhancedCNNAdapter`` from ``core.enhanced_cnn_adapter``.
        2. Construct it with ``checkpoint_dir="models/enhanced_cnn"``.
        3. Log the adapter's ``model``, ``device`` and ``model_name``.
        4. Log its metric counters (inference count, training count and the
           length of ``training_data``).
        5. Add a BUY training sample for ETH/USDT.
        6. Add a SELL sample, then, if at least two samples are buffered,
           run ``train(epochs=1)`` and log the updated training metrics.

    Returns:
        True if every step ran without raising. False otherwise, after
        logging the exception together with its traceback. ``main()`` uses
        this flag to decide whether to continue.
    """
    try:
        logger.info("Testing CNN adapter initialization...")

        # Test 1: Import CNN adapter
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        logger.info("✅ CNN adapter import successful")

        # Test 2: Initialize CNN adapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info("✅ CNN adapter initialization successful")

        # Test 3: Check adapter attributes
        logger.info(f"CNN adapter model: {cnn_adapter.model}")
        logger.info(f"CNN adapter device: {cnn_adapter.device}")
        logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}")

        # Test 4: Check metrics tracking
        logger.info(f"Inference count: {cnn_adapter.inference_count}")
        logger.info(f"Training count: {cnn_adapter.training_count}")
        logger.info(f"Training data length: {len(cnn_adapter.training_data)}")

        # Test 5: Test simple training sample addition
        cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
        logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}")

        # Test 6: Add the second sample *before* checking the minimum. Checking
        # first meant a fresh (empty) buffer held only one sample at this
        # point, so the training branch below was unreachable.
        cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05)

        if len(cnn_adapter.training_data) >= 2:
            # Try training
            training_result = cnn_adapter.train(epochs=1)
            logger.info(f"✅ Training successful: {training_result}")

            # Check if metrics were updated
            logger.info(f"Last training time: {cnn_adapter.last_training_time}")
            logger.info(f"Last training loss: {cnn_adapter.last_training_loss}")
            logger.info(f"Training count: {cnn_adapter.training_count}")
        else:
            logger.info("⚠️ Not enough training samples for training test")

        return True

    except Exception as e:
        # logger.exception records the traceback through the logging handlers
        # instead of writing it separately via traceback.print_exc().
        logger.exception(f"❌ CNN adapter test failed: {e}")
        return False
|
|
|
|
def _make_sample_bars(bar_cls, symbol, timeframe, end_time, count=10, base_price=3500.0):
    """Return ``count`` synthetic bars for ``symbol`` at ``timeframe``, oldest first.

    Consecutive bars are one ``timeframe`` interval apart and the newest one
    is stamped ``end_time``, so timestamps are distinct and strictly ascending.
    Bar ``i`` has open/high/low/close of ``base_price + i`` offset by
    0 / +10 / -10 / +5 and a constant volume of 1000.0.

    Args:
        bar_cls: Callable building one bar from keyword arguments
            (``symbol``, ``timestamp``, ``open``, ``high``, ``low``, ``close``,
            ``volume``, ``timeframe``). In this script it is ``OHLCVBar``.
        symbol: Trading pair written into every bar.
        timeframe: One of "1s", "1m", "1h", "1d".
        end_time: Timestamp of the newest bar.
        count: Number of bars to build.
        base_price: Open price of the oldest bar.

    Raises:
        KeyError: If ``timeframe`` is not one of the supported labels.
    """
    from datetime import timedelta

    step = {
        "1s": timedelta(seconds=1),
        "1m": timedelta(minutes=1),
        "1h": timedelta(hours=1),
        "1d": timedelta(days=1),
    }[timeframe]

    return [
        bar_cls(
            symbol=symbol,
            timestamp=end_time - step * (count - 1 - i),
            open=base_price + i,
            high=base_price + 10.0 + i,
            low=base_price - 10.0 + i,
            close=base_price + 5.0 + i,
            volume=1000.0,
            timeframe=timeframe,
        )
        for i in range(count)
    ]


def test_base_data_input():
    """Build a synthetic BaseDataInput and exercise its validation and features.

    Creates ten bars per series, all ending at the same instant: ETH/USDT at
    1s, 1m, 1h and 1d, plus BTC/USDT at 1s. It wraps them in
    ``BaseDataInput``, logs the result of ``validate()`` and logs the shape of
    ``get_feature_vector()``.

    Returns:
        The constructed ``BaseDataInput`` on success (``test_cnn_prediction``
        feeds it to the adapter). None if any step raised; the exception and
        its traceback are logged.
    """
    try:
        logger.info("Testing BaseDataInput creation...")

        # Test 1: Import BaseDataInput
        from core.data_models import BaseDataInput, OHLCVBar
        logger.info("✅ BaseDataInput import successful")

        # Test 2: Create sample OHLCV bars. Each field gets its own series with
        # the matching symbol/timeframe label and spacing, instead of one
        # "1s" ETH series (with near-identical timestamps) reused everywhere.
        now = datetime.now()
        eth_bars = {
            timeframe: _make_sample_bars(OHLCVBar, "ETH/USDT", timeframe, now)
            for timeframe in ("1s", "1m", "1h", "1d")
        }
        btc_bars_1s = _make_sample_bars(OHLCVBar, "BTC/USDT", "1s", now)
        logger.info(f"✅ Created {len(eth_bars['1s'])} sample OHLCV bars per series")

        # Test 3: Create BaseDataInput
        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=now,
            ohlcv_1s=eth_bars["1s"],
            ohlcv_1m=eth_bars["1m"],
            ohlcv_1h=eth_bars["1h"],
            ohlcv_1d=eth_bars["1d"],
            btc_ohlcv_1s=btc_bars_1s,
        )
        logger.info("✅ BaseDataInput created successfully")

        # Test 4: Validate BaseDataInput. A failed validation is not fatal here
        # (the feature vector is still checked), but it is logged as a warning
        # so it does not look like a success in the output.
        is_valid = base_data.validate()
        log_validation = logger.info if is_valid else logger.warning
        log_validation(f"BaseDataInput validation: {is_valid}")

        # Test 5: Get feature vector
        feature_vector = base_data.get_feature_vector()
        logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}")

        return base_data

    except Exception as e:
        logger.exception(f"❌ BaseDataInput test failed: {e}")
        return None
|
|
|
|
def test_cnn_prediction():
    """Run a single prediction through EnhancedCNNAdapter on synthetic input.

    Constructs a fresh adapter and obtains input from
    ``test_base_data_input()``. It then calls ``predict`` and logs the
    predicted action, the confidence and the adapter's inference metrics.

    Returns:
        True if prediction completed without raising. False if the input
        could not be built or any step raised (the traceback is logged).
    """
    try:
        logger.info("Testing CNN prediction...")

        # Get CNN adapter and base data
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        base_data = test_base_data_input()
        # test_base_data_input() signals failure with None. Test for that
        # explicitly rather than relying on the truthiness of the returned object.
        if base_data is None:
            logger.error("❌ Cannot test prediction without valid BaseDataInput")
            return False

        # Test prediction
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})")

        # Check if metrics were updated
        logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}")
        logger.info(f"Last inference time: {cnn_adapter.last_inference_time}")
        logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}")

        return True

    except Exception as e:
        logger.exception(f"❌ CNN prediction test failed: {e}")
        return False
|
|
|
|
def main():
    """Run the CNN integration checks in sequence, halting at the first failure.

    Returns:
        True when every check reported success, otherwise False.
    """
    logger.info("🧪 Starting CNN Integration Tests")

    # Each entry pairs a check with the message logged when it fails. The
    # sequence is significant: a failing adapter check stops the run before
    # the prediction check is attempted.
    checks = (
        (test_cnn_adapter, "❌ CNN adapter test failed - stopping"),
        (test_cnn_prediction, "❌ CNN prediction test failed - stopping"),
    )
    for run_check, failure_message in checks:
        if run_check():
            continue
        logger.error(failure_message)
        return False

    logger.info("✅ All CNN integration tests passed!")
    logger.info("🎯 The CNN adapter should now work properly in the dashboard")
    return True
|
|
|
|
if __name__ == "__main__":
    # Process exit status mirrors the overall result: 0 when every check
    # passed, 1 otherwise.
    sys.exit(0 if main() else 1)