#!/usr/bin/env python3
"""
Comprehensive Training Integration Tests

This module consolidates and improves test functionality from multiple test files:
- CNN training tests (from test_cnn_only.py, test_training.py)
- Model testing (from test_model.py)
- Chart data testing (from test_chart_data.py)
- Integration testing between components

Usage: run this module directly; pass "quick", "data", or "training" as the first
argument to run a subset of the suites (any other value, or no argument, runs all).
"""

import sys
import os
import logging
import time
import unittest
import tempfile
from pathlib import Path

# Add project root to path
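# (assumes this test module lives one directory below the project root, e.g. in tests/)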
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData

logger = logging.getLogger(__name__)


class TestDataProviders(unittest.TestCase):
    """Test suite for data provider functionality"""

    def test_binance_historical_data(self):
        """Test Binance historical data fetching"""
        logger.info("Testing Binance historical data fetch...")

        try:
            binance_data = BinanceHistoricalData()
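            # The second argument is assumed to be the candle interval in seconds
            # (60 = 1m candles) and the third the number of candles to fetch.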
            df = binance_data.get_historical_candles("ETH/USDT", 60, 100)

            self.assertIsNotNone(df, "Should fetch data successfully")
            self.assertFalse(df.empty, "Data should not be empty")
            self.assertGreater(len(df), 0, "Should have candles")

            # Verify data structure
            required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
            for col in required_columns:
                self.assertIn(col, df.columns, f"Should have {col} column")

            logger.info(f"✅ Successfully fetched {len(df)} candles")

        except Exception as e:
            logger.warning(f"Binance API test failed: {e}")
            self.skipTest("Binance API not available")

    def test_tick_storage(self):
        """Test TickStorage functionality"""
        logger.info("Testing TickStorage data loading...")

        try:
            tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
            success = tick_storage.load_historical_data("ETH/USDT", limit=100)

            if success:
                # Check timeframes
                for tf in ["1m", "5m", "1h"]:
                    candles = tick_storage.get_candles(tf)
                    logger.info(f"  {tf}: {len(candles)} candles")

                logger.info("✅ TickStorage working correctly")
            else:
                self.skipTest("Could not load tick storage data")

        except Exception as e:
            logger.warning(f"TickStorage test failed: {e}")
            self.skipTest("TickStorage not available")

    def test_chart_initialization(self):
        """Test RealTimeChart initialization"""
        logger.info("Testing RealTimeChart initialization...")

        try:
            chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)

            # Test getting candles
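            # get_candles(60) is assumed to take the interval in seconds (60 -> 1m candles)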
            candles_1m = chart.get_candles(60)

            self.assertIsInstance(candles_1m, list, "Should return list of candles")
            logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")

        except Exception as e:
            logger.warning(f"Chart initialization failed: {e}")
            self.skipTest("Chart initialization not available")


class TestCNNTraining(unittest.TestCase):
    """Test suite for CNN training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_cnn_quick_training(self):
        """Test quick CNN training with small dataset"""
        logger.info("Testing CNN quick training...")

        trainer = None  # ensure the finally block below has a reference even if setup fails
        try:
            config = get_config()

            # Test configuration
            symbols = ['ETH/USDT']
            timeframes = ['1m', '5m']
            num_samples = 100  # Very small for testing
            epochs = 1
            batch_size = 16

            # Override config for quick test
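            # NOTE: this writes straight into the private config dict; it assumes
            # the object returned by get_config() exposes its settings via `_config`.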
            config._config['timeframes'] = timeframes

            trainer = CNNTrainer(config)
            trainer.batch_size = batch_size
            trainer.epochs = epochs

            # Train model
            save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
            results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
            self.assertIn('total_epochs', results, "Should have epoch count")
            self.assertIn('training_time', results, "Should have training time")

            # Verify model was saved
            self.assertTrue(os.path.exists(save_path), "Model should be saved")

            logger.info("✅ CNN training completed successfully")
            logger.info(f"   Best accuracy: {results['best_val_accuracy']:.4f}")
            logger.info(f"   Training time: {results['training_time']:.2f}s")

        except Exception as e:
            logger.error(f"CNN training test failed: {e}")
            raise
        finally:
            if trainer is not None and hasattr(trainer, 'close_tensorboard'):
                trainer.close_tensorboard()


class TestRLTraining(unittest.TestCase):
    """Test suite for RL training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_rl_quick_training(self):
        """Test quick RL training with small dataset"""
        logger.info("Testing RL quick training...")

        try:
            # Setup minimal configuration
            data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
            trainer = RLTrainer(data_provider)

            # Configure for very quick test
            trainer.num_episodes = 5
            trainer.max_steps_per_episode = 50
            trainer.evaluation_frequency = 3
            trainer.save_frequency = 10  # Don't save during test

            # Train
            save_path = os.path.join(self.temp_dir, 'test_rl.pt')
            results = trainer.train(save_path=save_path)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('total_episodes', results, "Should have episode count")
            self.assertIn('best_reward', results, "Should have best reward")
            self.assertIn('final_evaluation', results, "Should have final evaluation")

            logger.info("✅ RL training completed successfully")
            logger.info(f"   Total episodes: {results['total_episodes']}")
            logger.info(f"   Best reward: {results['best_reward']:.4f}")

        except Exception as e:
            logger.error(f"RL training test failed: {e}")
            raise


class TestExtendedTraining(unittest.TestCase):
    """Test suite for extended training functionality (from test_model.py)"""

    def test_metrics_tracking(self):
        """Test comprehensive metrics tracking functionality"""
        logger.info("Testing extended metrics tracking...")

        # Test metrics history structure
        metrics_history = {
            "epoch": [],
            "train_loss": [],
            "val_loss": [],
            "train_acc": [],
            "val_acc": [],
            "train_pnl": [],
            "val_pnl": [],
            "train_win_rate": [],
            "val_win_rate": [],
            "signal_distribution": []
        }

        # Simulate adding metrics
        for epoch in range(3):
            metrics_history["epoch"].append(epoch + 1)
            metrics_history["train_loss"].append(0.5 - epoch * 0.1)
            metrics_history["val_loss"].append(0.6 - epoch * 0.1)
            metrics_history["train_acc"].append(0.6 + epoch * 0.05)
            metrics_history["val_acc"].append(0.55 + epoch * 0.05)
            metrics_history["train_pnl"].append(epoch * 0.1)
            metrics_history["val_pnl"].append(epoch * 0.08)
            metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
            metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
            metrics_history["signal_distribution"].append({
                "BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
            })

        # Verify structure
        self.assertEqual(len(metrics_history["epoch"]), 3)
        self.assertEqual(len(metrics_history["train_loss"]), 3)
        self.assertEqual(len(metrics_history["signal_distribution"]), 3)

        # Verify improvement
        self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
        self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])

        logger.info("✅ Metrics tracking test passed")

    def test_signal_distribution_calculation(self):
        """Test signal distribution calculation"""
        import numpy as np

        # Mock predictions (SELL=0, HOLD=1, BUY=2)
        predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])

        buy_count = np.sum(predictions == 2)
        sell_count = np.sum(predictions == 0)
        hold_count = np.sum(predictions == 1)
        total = len(predictions)

        distribution = {
            "BUY": buy_count / total,
            "SELL": sell_count / total,
            "HOLD": hold_count / total
        }

        # Verify calculations
        self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
        self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
        self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
        self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)

        logger.info("✅ Signal distribution calculation test passed")


class TestIntegration(unittest.TestCase):
    """Integration tests between components"""

    def test_training_pipeline_integration(self):
        """Test that CNN and RL training can work together"""
        logger.info("Testing training pipeline integration...")

        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                # Quick CNN training
                config = get_config()
                config._config['timeframes'] = ['1m']

                cnn_trainer = CNNTrainer(config)
                cnn_trainer.epochs = 1
                cnn_trainer.batch_size = 8

                cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
                cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)

                # Quick RL training
                data_provider = DataProvider(['ETH/USDT'], ['1m'])
                rl_trainer = RLTrainer(data_provider)
                rl_trainer.num_episodes = 3
                rl_trainer.max_steps_per_episode = 25

                rl_path = os.path.join(temp_dir, 'test_rl.pt')
                rl_results = rl_trainer.train(save_path=rl_path)

                # Verify both trained successfully
                self.assertIsInstance(cnn_results, dict)
                self.assertIsInstance(rl_results, dict)
                self.assertTrue(os.path.exists(cnn_path))
                self.assertTrue(os.path.exists(rl_path))

                logger.info("✅ Training pipeline integration test passed")

            except Exception as e:
                logger.error(f"Integration test failed: {e}")
                raise
            finally:
                if 'cnn_trainer' in locals() and hasattr(cnn_trainer, 'close_tensorboard'):
                    cnn_trainer.close_tensorboard()


def run_quick_tests():
    """Run only the quickest tests for fast validation"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_data_tests():
    """Run data provider tests"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_training_tests():
    """Run training tests (slower)"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_all_tests():
    """Run all test suites"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


if __name__ == "__main__":
    setup_logging()
    logger.info("Running comprehensive training integration tests...")

    if len(sys.argv) > 1:
        test_type = sys.argv[1]
        if test_type == "quick":
            success = run_quick_tests()
        elif test_type == "data":
            success = run_data_tests()
        elif test_type == "training":
            success = run_training_tests()
        else:
            success = run_all_tests()
    else:
        success = run_all_tests()

    if success:
        logger.info("✅ All tests passed!")
        sys.exit(0)
    else:
        logger.error("❌ Some tests failed!")
        sys.exit(1)