87 lines
3.1 KiB
Python
87 lines
3.1 KiB
Python
"""
|
|
Test Enhanced CNN Adapter
|
|
|
|
This script tests the EnhancedCNNAdapter using the standardized input format.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
|
|
from core.standardized_data_provider import StandardizedDataProvider
|
|
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
|
from core.data_models import create_model_output
|
|
|
|
# Script-wide logging: INFO and above, each record prefixed with timestamp,
# logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
def test_cnn_adapter():
    """Smoke-test EnhancedCNNAdapter end to end using the standardized input format.

    Steps:
      1. Build a StandardizedDataProvider for ETH/USDT and BTC/USDT.
      2. Create the adapter and load its best checkpoint if available.
      3. Predict on the ETH/USDT base data input and store the model output.
      4. Add one simulated training sample, train for one epoch, predict again.
      5. Query the provider's model output manager for current outputs and
         performance metrics.

    Any failure is logged with its traceback and then re-raised. This lets
    pytest, and the script's exit status, report the failure instead of
    treating it as a pass.

    Raises:
        RuntimeError: if the data provider returns no base data input.
    """
    symbol = 'ETH/USDT'
    try:
        # Initialize data provider
        symbols = [symbol, 'BTC/USDT']
        timeframes = ['1s', '1m', '1h', '1d']
        data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)

        # Initialize CNN adapter and load best checkpoint if available
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        cnn_adapter.load_best_checkpoint()

        # Get standardized input data
        logger.info("Getting standardized input data...")
        base_data = data_provider.get_base_data_input(symbol)
        if base_data is None:
            # An early return here used to make the test "pass" without
            # exercising the adapter; missing input is a failure.
            raise RuntimeError("Failed to get base data input")

        # Make prediction and store the output with the provider
        logger.info("Making prediction...")
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
        data_provider.store_model_output(model_output)

        # Add training sample (simulated)
        logger.info("Adding training sample...")
        cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)

        # Train model
        logger.info("Training model...")
        train_metrics = cnn_adapter.train(epochs=1)
        logger.info(f"Training metrics: {train_metrics}")

        # Make another prediction after training
        logger.info("Making another prediction...")
        model_output = cnn_adapter.predict(base_data)
        logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")

        # Test model output manager
        logger.info("Testing model output manager...")
        output_manager = data_provider.get_model_output_manager()

        current_outputs = output_manager.get_all_current_outputs(symbol)
        logger.info(f"Current outputs: {len(current_outputs)} models")

        perf_metrics = output_manager.evaluate_model_performance(symbol, 'enhanced_cnn_v1')
        logger.info(f"Performance metrics: {perf_metrics}")

        logger.info("Test completed successfully")
    except Exception as e:
        # Log with traceback, then propagate so the failure is not swallowed.
        logger.exception(f"Error in test: {e}")
        raise
|
|
|
|
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    test_cnn_adapter()