checkpoint manager
This commit is contained in:
87
test_enhanced_cnn_adapter.py
Normal file
87
test_enhanced_cnn_adapter.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Test Enhanced CNN Adapter
|
||||
|
||||
This script tests the EnhancedCNNAdapter with standardized input format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from core.data_models import create_model_output
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_cnn_adapter():
|
||||
"""Test the EnhancedCNNAdapter with standardized input format"""
|
||||
try:
|
||||
# Initialize data provider
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
|
||||
|
||||
# Initialize CNN adapter
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
# Load best checkpoint if available
|
||||
cnn_adapter.load_best_checkpoint()
|
||||
|
||||
# Get standardized input data
|
||||
logger.info("Getting standardized input data...")
|
||||
base_data = data_provider.get_base_data_input('ETH/USDT')
|
||||
|
||||
if base_data is None:
|
||||
logger.error("Failed to get base data input")
|
||||
return
|
||||
|
||||
# Make prediction
|
||||
logger.info("Making prediction...")
|
||||
model_output = cnn_adapter.predict(base_data)
|
||||
|
||||
# Log prediction
|
||||
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
|
||||
|
||||
# Store model output
|
||||
data_provider.store_model_output(model_output)
|
||||
|
||||
# Add training sample (simulated)
|
||||
logger.info("Adding training sample...")
|
||||
cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)
|
||||
|
||||
# Train model
|
||||
logger.info("Training model...")
|
||||
metrics = cnn_adapter.train(epochs=1)
|
||||
|
||||
# Log training metrics
|
||||
logger.info(f"Training metrics: {metrics}")
|
||||
|
||||
# Make another prediction
|
||||
logger.info("Making another prediction...")
|
||||
model_output = cnn_adapter.predict(base_data)
|
||||
|
||||
# Log prediction
|
||||
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
|
||||
|
||||
# Test model output manager
|
||||
logger.info("Testing model output manager...")
|
||||
output_manager = data_provider.get_model_output_manager()
|
||||
|
||||
# Get current outputs
|
||||
current_outputs = output_manager.get_all_current_outputs('ETH/USDT')
|
||||
logger.info(f"Current outputs: {len(current_outputs)} models")
|
||||
|
||||
# Evaluate model performance
|
||||
metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')
|
||||
logger.info(f"Performance metrics: {metrics}")
|
||||
|
||||
logger.info("Test completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in test: {e}", exc_info=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cnn_adapter()
|
Reference in New Issue
Block a user