wip cnn training and cob
This commit is contained in:
175
test_cnn_integration.py
Normal file
175
test_cnn_integration.py
Normal file
@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test CNN Integration
|
||||
|
||||
This script tests if the CNN adapter is working properly and identifies issues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_cnn_adapter():
|
||||
"""Test CNN adapter initialization and basic functionality"""
|
||||
try:
|
||||
logger.info("Testing CNN adapter initialization...")
|
||||
|
||||
# Test 1: Import CNN adapter
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
logger.info("✅ CNN adapter import successful")
|
||||
|
||||
# Test 2: Initialize CNN adapter
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
logger.info("✅ CNN adapter initialization successful")
|
||||
|
||||
# Test 3: Check adapter attributes
|
||||
logger.info(f"CNN adapter model: {cnn_adapter.model}")
|
||||
logger.info(f"CNN adapter device: {cnn_adapter.device}")
|
||||
logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}")
|
||||
|
||||
# Test 4: Check metrics tracking
|
||||
logger.info(f"Inference count: {cnn_adapter.inference_count}")
|
||||
logger.info(f"Training count: {cnn_adapter.training_count}")
|
||||
logger.info(f"Training data length: {len(cnn_adapter.training_data)}")
|
||||
|
||||
# Test 5: Test simple training sample addition
|
||||
cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
|
||||
logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}")
|
||||
|
||||
# Test 6: Test training if we have enough samples
|
||||
if len(cnn_adapter.training_data) >= 2:
|
||||
# Add another sample to have minimum for training
|
||||
cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05)
|
||||
|
||||
# Try training
|
||||
training_result = cnn_adapter.train(epochs=1)
|
||||
logger.info(f"✅ Training successful: {training_result}")
|
||||
|
||||
# Check if metrics were updated
|
||||
logger.info(f"Last training time: {cnn_adapter.last_training_time}")
|
||||
logger.info(f"Last training loss: {cnn_adapter.last_training_loss}")
|
||||
logger.info(f"Training count: {cnn_adapter.training_count}")
|
||||
else:
|
||||
logger.info("⚠️ Not enough training samples for training test")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ CNN adapter test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_base_data_input():
|
||||
"""Test BaseDataInput creation"""
|
||||
try:
|
||||
logger.info("Testing BaseDataInput creation...")
|
||||
|
||||
# Test 1: Import BaseDataInput
|
||||
from core.data_models import BaseDataInput, OHLCVBar, COBData
|
||||
logger.info("✅ BaseDataInput import successful")
|
||||
|
||||
# Test 2: Create sample OHLCV bars
|
||||
sample_bars = []
|
||||
for i in range(10): # Create 10 sample bars
|
||||
bar = OHLCVBar(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
open=3500.0 + i,
|
||||
high=3510.0 + i,
|
||||
low=3490.0 + i,
|
||||
close=3505.0 + i,
|
||||
volume=1000.0,
|
||||
timeframe="1s"
|
||||
)
|
||||
sample_bars.append(bar)
|
||||
|
||||
logger.info(f"✅ Created {len(sample_bars)} sample OHLCV bars")
|
||||
|
||||
# Test 3: Create BaseDataInput
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
ohlcv_1s=sample_bars,
|
||||
ohlcv_1m=sample_bars,
|
||||
ohlcv_1h=sample_bars,
|
||||
ohlcv_1d=sample_bars,
|
||||
btc_ohlcv_1s=sample_bars
|
||||
)
|
||||
|
||||
logger.info("✅ BaseDataInput created successfully")
|
||||
|
||||
# Test 4: Validate BaseDataInput
|
||||
is_valid = base_data.validate()
|
||||
logger.info(f"BaseDataInput validation: {is_valid}")
|
||||
|
||||
# Test 5: Get feature vector
|
||||
feature_vector = base_data.get_feature_vector()
|
||||
logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}")
|
||||
|
||||
return base_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ BaseDataInput test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def test_cnn_prediction():
|
||||
"""Test CNN prediction with BaseDataInput"""
|
||||
try:
|
||||
logger.info("Testing CNN prediction...")
|
||||
|
||||
# Get CNN adapter and base data
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
base_data = test_base_data_input()
|
||||
if not base_data:
|
||||
logger.error("❌ Cannot test prediction without valid BaseDataInput")
|
||||
return False
|
||||
|
||||
# Test prediction
|
||||
model_output = cnn_adapter.predict(base_data)
|
||||
logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})")
|
||||
|
||||
# Check if metrics were updated
|
||||
logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}")
|
||||
logger.info(f"Last inference time: {cnn_adapter.last_inference_time}")
|
||||
logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ CNN prediction test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
logger.info("🧪 Starting CNN Integration Tests")
|
||||
|
||||
# Test 1: CNN Adapter
|
||||
if not test_cnn_adapter():
|
||||
logger.error("❌ CNN adapter test failed - stopping")
|
||||
return False
|
||||
|
||||
# Test 2: CNN Prediction
|
||||
if not test_cnn_prediction():
|
||||
logger.error("❌ CNN prediction test failed - stopping")
|
||||
return False
|
||||
|
||||
logger.info("✅ All CNN integration tests passed!")
|
||||
logger.info("🎯 The CNN adapter should now work properly in the dashboard")
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
Reference in New Issue
Block a user