# Repository view metadata (captured with this listing):
#   path:     gogo2/test_enhanced_cnn_adapter.py
#   modified: 2025-07-23 22:11:19 +03:00
#   size:     87 lines, 3.1 KiB (Python)
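"""
Test Enhanced CNN Adapter

Smoke test for EnhancedCNNAdapter fed with the standardized input produced by
StandardizedDataProvider. It covers prediction, storing the model output, a
simulated training sample, one training epoch, a second prediction, and
queries against the provider's model output manager.
"""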
import logging
import time
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Root logging for this script: INFO and above, each record stamped with time,
# logger name and level. Per the stdlib docs, basicConfig is a no-op if the
# root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _log_prediction(model_output):
    """Log the action and confidence from one adapter prediction.

    Args:
        model_output: object returned by ``EnhancedCNNAdapter.predict``; it
            must expose ``predictions['action']`` and a ``confidence`` value
            that supports ``:.4f`` formatting.
    """
    logger.info(
        f"Prediction: {model_output.predictions['action']} "
        f"with confidence {model_output.confidence:.4f}"
    )


def test_cnn_adapter():
    """Run an end-to-end smoke test of EnhancedCNNAdapter on standardized input.

    Builds a StandardizedDataProvider for ETH/USDT and BTC/USDT over the 1s,
    1m, 1h and 1d timeframes. Creates the adapter with checkpoints under
    ``models/enhanced_cnn``, loading the best one if available. Then, on
    ETH/USDT data, it:

    - predicts and stores the output with the provider,
    - adds one simulated training sample,
    - trains for one epoch,
    - predicts again,
    - queries the provider's model output manager.

    Raises:
        RuntimeError: if the provider returns no base data input for ETH/USDT.
        Exception: whatever the provider, adapter or output manager raises.
            Errors are logged and then re-raised. As a result, pytest (which
            collects this ``test_*`` function) reports a failure and a script
            run exits with a non-zero status, instead of passing silently.
    """
    symbol = 'ETH/USDT'
    try:
        # Initialize data provider
        symbols = [symbol, 'BTC/USDT']
        timeframes = ['1s', '1m', '1h', '1d']
        data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)

        # Initialize CNN adapter and load its best checkpoint if available
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        cnn_adapter.load_best_checkpoint()

        # Get standardized input data
        logger.info("Getting standardized input data...")
        base_data = data_provider.get_base_data_input(symbol)
        if base_data is None:
            # Missing input is a failure. Returning here would let the test
            # "pass" without exercising the adapter at all.
            raise RuntimeError("Failed to get base data input")

        # First prediction, then hand the output to the provider
        logger.info("Making prediction...")
        model_output = cnn_adapter.predict(base_data)
        _log_prediction(model_output)
        data_provider.store_model_output(model_output)

        # Add one simulated training sample ('BUY', 0.05); presumably action
        # and target/reward. TODO confirm against add_training_sample.
        logger.info("Adding training sample...")
        cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)

        # Train for a single epoch
        logger.info("Training model...")
        train_metrics = cnn_adapter.train(epochs=1)
        logger.info(f"Training metrics: {train_metrics}")

        # Predict again after training
        logger.info("Making another prediction...")
        model_output = cnn_adapter.predict(base_data)
        _log_prediction(model_output)

        # Test model output manager
        logger.info("Testing model output manager...")
        output_manager = data_provider.get_model_output_manager()
        current_outputs = output_manager.get_all_current_outputs(symbol)
        logger.info(f"Current outputs: {len(current_outputs)} models")

        # NOTE(review): assumes the adapter stores its outputs under the model
        # name 'enhanced_cnn_v1'. Confirm against EnhancedCNNAdapter.
        perf_metrics = output_manager.evaluate_model_performance(symbol, 'enhanced_cnn_v1')
        logger.info(f"Performance metrics: {perf_metrics}")

        logger.info("Test completed successfully")
    except Exception as e:
        # Log for script runs, then re-raise so the failure reaches pytest and
        # the process exit status. exc_info is not passed: the propagating
        # exception's traceback is printed anyway, and logging it too would
        # show the same traceback twice.
        logger.error(f"Error in test: {e}")
        raise
if __name__ == "__main__":
    # Script entry point. When this module is imported instead (e.g.
    # collected by pytest), nothing runs here and pytest calls the test itself.
    test_cnn_adapter()